From 86384491e336fbfd5ab439bdc7e8b33ceac068ac Mon Sep 17 00:00:00 2001 From: Jean THOMAS Date: Tue, 9 Jun 2020 17:32:09 +0200 Subject: [PATCH] Add StrideConverter implementation --- gram/stream.py | 214 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 213 insertions(+), 1 deletion(-) diff --git a/gram/stream.py b/gram/stream.py index 675af64..e79a4e4 100644 --- a/gram/stream.py +++ b/gram/stream.py @@ -5,7 +5,7 @@ from nmigen.hdl.rec import * from nmigen.lib import fifo -__all__ = ["Endpoint", "SyncFIFO", "AsyncFIFO", "Buffer"] +__all__ = ["Endpoint", "SyncFIFO", "AsyncFIFO", "Buffer", "StrideConverter"] def _make_fanout(layout): @@ -142,3 +142,215 @@ class PipeValid(Elaboratable): class Buffer(PipeValid): pass # FIXME: Replace Buffer with PipeValid in codebase? + + +class _UpConverter(Elaboratable): + def __init__(self, nbits_from, nbits_to, ratio, reverse, + report_valid_token_count): + self.sink = sink = Endpoint([("data", nbits_from)]) + source_layout = [("data", nbits_to)] + if report_valid_token_count: + source_layout.append(("valid_token_count", bits_for(ratio))) + self.source = source = Endpoint(source_layout) + self.ratio = ratio + self._nbits_from = nbits_from + + def elaborate(self, platform): + m = Module() + + # control path + demux = Signal(max=ratio) + load_part = Signal() + strobe_all = Signal() + m.d.comb += [ + self.sink.ack.eq(~strobe_all | self.source.ack), + self.source.stb.eq(strobe_all), + load_part.eq(self.sink.stb & self.sink.ack) + ] + + demux_last = ((demux == (ratio - 1)) | self.sink.eop) + + with m.If(source.ack): + m.d.sync += strobe_all.eq(0) + + with m.If(load_part): + with m.If(demux_last): + m.d.sync += [ + demux.eq(0), + strobe_all.eq(1), + ] + with m.Else(): + m.d.sync += demux.eq(demux+1) + + with m.If(self.source.stb & self.source.ack): + m.d.sync += self.source.eop.eq(self.sink.eop) + with m.Elif(self.sink.stb & self.sink.ack): + m.d.sync += self.source.eop.eq(self.sink.eop | self.source.eop) + + # data path + cases = {} + for i in range(ratio): + n = ratio-i-1 if reverse else i + cases[i] = source.data[n*self._nbits_from:(n+1)*self._nbits_from].eq(sink.data) + m.d.sync += If(load_part, Case(demux, cases)) + + if report_valid_token_count: + with m.If(load_part): + m.d.sync += self.source.valid_token_count.eq(demux + 1) + + return m + + +class _DownConverter(Elaboratable): + def __init__(self, nbits_from, nbits_to, ratio, reverse, + report_valid_token_count): + self.sink = sink = Endpoint([("data", nbits_from)]) + source_layout = [("data", nbits_to)] + if report_valid_token_count: + source_layout.append(("valid_token_count", 1)) + self.source = source = Endpoint(source_layout) + self.ratio = ratio + + def elaborate(self, platform): + m = Module() + + # control path + mux = Signal(range(self.ratio)) + last = Signal() + m.d.comb += [ + last.eq(mux == (self.ratio-1)), + self.source.stb.eq(self.sink.stb), + self.source.eop.eq(sink.eop & last), + self.sink.ack.eq(last & self.source.ack) + ] + with m.If(self.source.stb & self.source.ack): + with m.If(last): + m.d.sync += mux.eq(0) + with m.Else(): + m.d.sync += mux.eq(mux+1) + + # data path + cases = {} + for i in range(ratio): + n = ratio-i-1 if reverse else i + cases[i] = self.source.data.eq(self.sink.data[n*nbits_to:(n+1)*nbits_to]) + m.d.comb += Case(mux, cases).makedefault() + + if report_valid_token_count: + m.d.comb += self.source.valid_token_count.eq(last) + + return m + + +class _IdentityConverter(Elaboratable): + def __init__(self, nbits_from, nbits_to, ratio, reverse, + report_valid_token_count): + self.sink = Endpoint([("data", nbits_from)]) + source_layout = [("data", nbits_to)] + if report_valid_token_count: + source_layout.append(("valid_token_count", 1)) + self.source = Endpoint(source_layout) + assert ratio == 1 + self.ratio = ratio + self._report_valid_token_count = report_valid_token_count + + def elaborate(self, platform): + m = Module() + + m.d.comb += self.sink.connect(self.source) + if self._report_valid_token_count: + m.d.comb += self.source.valid_token_count.eq(1) + + return m + + +def _get_converter_ratio(nbits_from, nbits_to): + if nbits_from > nbits_to: + specialized_cls = _DownConverter + if nbits_from % nbits_to: + raise ValueError("Ratio must be an int") + ratio = nbits_from//nbits_to + elif nbits_from < nbits_to: + specialized_cls = _UpConverter + if nbits_to % nbits_from: + raise ValueError("Ratio must be an int") + ratio = nbits_to//nbits_from + else: + specialized_cls = _IdentityConverter + ratio = 1 + + return specialized_cls, ratio + + +class Converter(Elaboratable): + def __init__(self, nbits_from, nbits_to, reverse=False, + report_valid_token_count=False): + cls, ratio = _get_converter_ratio(nbits_from, nbits_to) + self.specialized = cls(nbits_from, nbits_to, ratio, + reverse, report_valid_token_count) + self.sink = self.specialized.sink + self.source = self.specialized.source + + def elaborate(self, platform): + m = Module() + + m.submodules += self.specialized + + return m + + +class StrideConverter(Elaboratable): + def __init__(self, layout_from, layout_to, *args, **kwargs): + self.sink = sink = Endpoint(layout_from) + self.source = source = Endpoint(layout_to) + + nbits_from = len(sink.payload.raw_bits()) + nbits_to = len(source.payload.raw_bits()) + self.converter = Converter(nbits_from, nbits_to, *args, **kwargs) + + def elaborate(self, platform): + m = Module() + + nbits_from = len(sink.payload.raw_bits()) + nbits_to = len(source.payload.raw_bits()) + + m.d.submodules += self.converter + + # cast sink to converter.sink (user fields --> raw bits) + m.d.comb += [ + self.converter.sink.stb.eq(sink.stb), + self.converter.sink.eop.eq(sink.eop), + sink.ack.eq(self.converter.sink.ack) + ] + if isinstance(self.converter.specialized, _DownConverter): + ratio = self.converter.specialized.ratio + for i in range(ratio): + j = 0 + for name, width in layout_to: + src = getattr(sink, name)[i*width:(i+1)*width] + dst = self.converter.sink.data[i*nbits_to+j:i*nbits_to+j+width] + self.comb += dst.eq(src) + j += width + else: + m.d.comb += self.converter.sink.data.eq(sink.payload.raw_bits()) + + + # cast converter.source to source (raw bits --> user fields) + m.D.comb += [ + source.stb.eq(self.converter.source.stb), + source.eop.eq(self.converter.source.eop), + self.converter.source.ack.eq(source.ack) + ] + if isinstance(self.converter.specialized, _UpConverter): + ratio = self.converter.specialized.ratio + for i in range(ratio): + j = 0 + for name, width in layout_from: + src = self.converter.source.data[i*nbits_from+j:i*nbits_from+j+width] + dst = getattr(source, name)[i*width:(i+1)*width] + m.d.comb += dst.eq(src) + j += width + else: + m.d.comb += source.payload.raw_bits().eq(self.converter.source.data) + + return m -- 2.30.2