From 5eb455681447abe04a337bf65fe37fa1785a61e0 Mon Sep 17 00:00:00 2001 From: Jean THOMAS Date: Wed, 10 Jun 2020 10:27:55 +0200 Subject: [PATCH] Finish porting stream classes to nMigen --- gram/stream.py | 101 ++++++++++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 43 deletions(-) diff --git a/gram/stream.py b/gram/stream.py index 9f30c1e..c97e5aa 100644 --- a/gram/stream.py +++ b/gram/stream.py @@ -154,23 +154,25 @@ class _UpConverter(Elaboratable): self.source = source = Endpoint(source_layout) self.ratio = ratio self._nbits_from = nbits_from + self._reverse = reverse + self._report_valid_token_count = report_valid_token_count def elaborate(self, platform): m = Module() # control path - demux = Signal(max=ratio) + demux = Signal(range(self.ratio)) load_part = Signal() strobe_all = Signal() m.d.comb += [ - self.sink.ack.eq(~strobe_all | self.source.ack), - self.source.stb.eq(strobe_all), - load_part.eq(self.sink.stb & self.sink.ack) + self.sink.ready.eq(~strobe_all | self.source.ready), + self.source.valid.eq(strobe_all), + load_part.eq(self.sink.valid & self.sink.ready) ] - demux_last = ((demux == (ratio - 1)) | self.sink.eop) + demux_last = ((demux == (self.ratio - 1)) | self.sink.last) - with m.If(source.ack): + with m.If(self.source.ready): m.d.sync += strobe_all.eq(0) with m.If(load_part): @@ -182,19 +184,20 @@ class _UpConverter(Elaboratable): with m.Else(): m.d.sync += demux.eq(demux+1) - with m.If(self.source.stb & self.source.ack): - m.d.sync += self.source.eop.eq(self.sink.eop) - with m.Elif(self.sink.stb & self.sink.ack): - m.d.sync += self.source.eop.eq(self.sink.eop | self.source.eop) + with m.If(self.source.valid & self.source.ready): + m.d.sync += self.source.last.eq(self.sink.last) + with m.Elif(self.sink.valid & self.sink.ready): + m.d.sync += self.source.last.eq(self.sink.last | self.source.last) # data path - cases = {} - for i in range(ratio): - n = ratio-i-1 if reverse else i - cases[i] = source.data[n*self._nbits_from:(n+1)*self._nbits_from].eq(sink.data) - m.d.sync += If(load_part, Case(demux, cases)) + with m.If(load_part): + with m.Switch(demux): + for i in range(self.ratio): + with m.Case(i): + n = self.ratio-i-1 if self._reverse else i + m.d.sync += self.source.payload.lower()[n*self._nbits_from:(n+1)*self._nbits_from].eq(self.sink.payload) - if report_valid_token_count: + if self._report_valid_token_count: with m.If(load_part): m.d.sync += self.source.valid_token_count.eq(demux + 1) @@ -204,12 +207,15 @@ class _UpConverter(Elaboratable): class _DownConverter(Elaboratable): def __init__(self, nbits_from, nbits_to, ratio, reverse, report_valid_token_count): - self.sink = sink = Endpoint([("data", nbits_from)]) + self.sink = Endpoint([("data", nbits_from)]) source_layout = [("data", nbits_to)] if report_valid_token_count: source_layout.append(("valid_token_count", 1)) - self.source = source = Endpoint(source_layout) + self.source = Endpoint(source_layout) self.ratio = ratio + self._reverse = reverse + self._nbits_to = nbits_to + self._report_valid_token_count = report_valid_token_count def elaborate(self, platform): m = Module() @@ -219,24 +225,30 @@ class _DownConverter(Elaboratable): last = Signal() m.d.comb += [ last.eq(mux == (self.ratio-1)), - self.source.stb.eq(self.sink.stb), - self.source.eop.eq(sink.eop & last), - self.sink.ack.eq(last & self.source.ack) + self.source.valid.eq(self.sink.valid), + self.source.last.eq(self.sink.last & last), + self.sink.ready.eq(last & self.source.ready) ] - with m.If(self.source.stb & self.source.ack): + with m.If(self.source.valid & self.source.ready): with m.If(last): m.d.sync += mux.eq(0) with m.Else(): m.d.sync += mux.eq(mux+1) # data path - cases = {} - for i in range(ratio): - n = ratio-i-1 if reverse else i - cases[i] = self.source.data.eq(self.sink.data[n*nbits_to:(n+1)*nbits_to]) - m.d.comb += Case(mux, cases).makedefault() + # cases = {} + # for i in range(self.ratio): + # n = self.ratio-i-1 if self._reverse else i + # cases[i] = self.source.data.eq(self.sink.data[n*self._nbits_to:(n+1)*self._nbits_to]) + # m.d.comb += Case(mux, cases).makedefault() + + with m.Switch(mux): + for i in range(self.ratio): + with m.Case(i): + n = self.ratio-i-1 if self._reverse else i + m.d.comb += self.source.data.eq(self.sink.data[n*self._nbits_to:(n+1)*self._nbits_to]) - if report_valid_token_count: + if self._report_valid_token_count: m.d.comb += self.source.valid_token_count.eq(last) return m @@ -286,7 +298,7 @@ class Converter(Elaboratable): def __init__(self, nbits_from, nbits_to, reverse=False, report_valid_token_count=False): cls, ratio = _get_converter_ratio(nbits_from, nbits_to) - self.specialized,_ = cls(nbits_from, nbits_to, ratio, + self.specialized = cls(nbits_from, nbits_to, ratio, reverse, report_valid_token_count) self.sink = self.specialized.sink self.source = self.specialized.source @@ -304,6 +316,9 @@ class StrideConverter(Elaboratable): self.sink = sink = Endpoint(layout_from) self.source = source = Endpoint(layout_to) + self._layout_to = layout_to + self._layout_from = layout_from + nbits_from = len(sink.payload.lower()) nbits_to = len(source.payload.lower()) self.converter = Converter(nbits_from, nbits_to, *args, **kwargs) @@ -314,44 +329,44 @@ class StrideConverter(Elaboratable): nbits_from = len(self.sink.payload.lower()) nbits_to = len(self.source.payload.lower()) - m.d.submodules += self.converter + m.submodules += self.converter # cast sink to converter.sink (user fields --> raw bits) m.d.comb += [ - self.converter.sink.stb.eq(self.sink.stb), - self.converter.sink.eop.eq(self.sink.eop), - sink.ack.eq(self.converter.sink.ack) + self.converter.sink.valid.eq(self.sink.valid), + self.converter.sink.last.eq(self.sink.last), + self.sink.ready.eq(self.converter.sink.ready) ] if isinstance(self.converter.specialized, _DownConverter): ratio = self.converter.specialized.ratio for i in range(ratio): j = 0 - for name, width in layout_to: + for name, width in self._layout_to.payload_layout: src = getattr(self.sink, name)[i*width:(i+1)*width] dst = self.converter.sink.data[i*nbits_to+j:i*nbits_to+j+width] m.d.comb += dst.eq(src) j += width else: - m.d.comb += self.converter.sink.data.eq(self.sink.payload.lower()) + m.d.comb += self.converter.sink.payload.eq(self.sink.payload.lower()) # cast converter.source to source (raw bits --> user fields) m.d.comb += [ - self.source.stb.eq(self.converter.source.stb), - self.source.eop.eq(self.converter.source.eop), - self.converter.source.ack.eq(self.source.ack) + self.source.valid.eq(self.converter.source.valid), + self.source.last.eq(self.converter.source.last), + self.converter.source.ready.eq(self.source.ready) ] if isinstance(self.converter.specialized, _UpConverter): ratio = self.converter.specialized.ratio for i in range(ratio): j = 0 - for name, width in layout_from: + for name, width in self._layout_from.payload_layout: src = self.converter.source.data[i*nbits_from+j:i*nbits_from+j+width] dst = getattr(self.source, name)[i*width:(i+1)*width] m.d.comb += dst.eq(src) j += width else: - m.d.comb += self.source.payload.lower().eq(self.converter.source.data) + m.d.comb += self.source.payload.lower().eq(self.converter.source.payload) return m @@ -372,8 +387,8 @@ class Pipeline(Elaboratable): def elaborate(self, platform): m = Module() - n = len(modules) - mod = modules[0] + n = len(self._modules) + mod = self._modules[0] for i in range(1, n): mod_n = self._modules[i] @@ -381,7 +396,7 @@ class Pipeline(Elaboratable): source = mod else: source = mod.source - if isinstance(m_n, Endpoint): + if isinstance(mod_n, Endpoint): sink = mod_n else: sink = mod_n.sink -- 2.30.2