Add StrideConverter implementation
authorJean THOMAS <git0@pub.jeanthomas.me>
Tue, 9 Jun 2020 15:32:09 +0000 (17:32 +0200)
committerJean THOMAS <git0@pub.jeanthomas.me>
Tue, 9 Jun 2020 15:32:09 +0000 (17:32 +0200)
gram/stream.py

index 675af64751940a75e8eb4066d33a1c0765794106..e79a4e4c9cd1ebbbbea265010d7f590853fc1f52 100644 (file)
@@ -5,7 +5,7 @@ from nmigen.hdl.rec import *
 from nmigen.lib import fifo
 
 
-__all__ = ["Endpoint", "SyncFIFO", "AsyncFIFO", "Buffer"]
+__all__ = ["Endpoint", "SyncFIFO", "AsyncFIFO", "Buffer", "StrideConverter"]
 
 
 def _make_fanout(layout):
@@ -142,3 +142,215 @@ class PipeValid(Elaboratable):
 
 class Buffer(PipeValid):
     pass  # FIXME: Replace Buffer with PipeValid in codebase?
+
+
+class _UpConverter(Elaboratable):
+    def __init__(self, nbits_from, nbits_to, ratio, reverse,
+                 report_valid_token_count):
+        self.sink = sink = Endpoint([("data", nbits_from)])
+        source_layout = [("data", nbits_to)]
+        if report_valid_token_count:
+            source_layout.append(("valid_token_count", bits_for(ratio)))
+        self.source = source = Endpoint(source_layout)
+        self.ratio = ratio
+        self._nbits_from  = nbits_from
+
+    def elaborate(self, platform):
+        m = Module()
+
+        # control path
+        demux = Signal(max=ratio)
+        load_part = Signal()
+        strobe_all = Signal()
+        m.d.comb += [
+            self.sink.ack.eq(~strobe_all | self.source.ack),
+            self.source.stb.eq(strobe_all),
+            load_part.eq(self.sink.stb & self.sink.ack)
+        ]
+
+        demux_last = ((demux == (ratio - 1)) | self.sink.eop)
+
+        with m.If(source.ack):
+            m.d.sync += strobe_all.eq(0)
+
+        with m.If(load_part):
+            with m.If(demux_last):
+                m.d.sync += [
+                    demux.eq(0),
+                    strobe_all.eq(1),
+                ]
+            with m.Else():
+                m.d.sync += demux.eq(demux+1)
+
+        with m.If(self.source.stb & self.source.ack):
+            m.d.sync += self.source.eop.eq(self.sink.eop)
+        with m.Elif(self.sink.stb & self.sink.ack):
+            m.d.sync += self.source.eop.eq(self.sink.eop | self.source.eop)
+
+        # data path
+        cases = {}
+        for i in range(ratio):
+            n = ratio-i-1 if reverse else i
+            cases[i] = source.data[n*self._nbits_from:(n+1)*self._nbits_from].eq(sink.data)
+        m.d.sync += If(load_part, Case(demux, cases))
+
+        if report_valid_token_count:
+            with m.If(load_part):
+                m.d.sync += self.source.valid_token_count.eq(demux + 1)
+
+        return m
+
+
+class _DownConverter(Elaboratable):
+    def __init__(self, nbits_from, nbits_to, ratio, reverse,
+                 report_valid_token_count):
+        self.sink = sink = Endpoint([("data", nbits_from)])
+        source_layout = [("data", nbits_to)]
+        if report_valid_token_count:
+            source_layout.append(("valid_token_count", 1))
+        self.source = source = Endpoint(source_layout)
+        self.ratio = ratio
+
+    def elaborate(self, platform):
+        m = Module()
+
+        # control path
+        mux = Signal(range(self.ratio))
+        last = Signal()
+        m.d.comb += [
+            last.eq(mux == (self.ratio-1)),
+            self.source.stb.eq(self.sink.stb),
+            self.source.eop.eq(sink.eop & last),
+            self.sink.ack.eq(last & self.source.ack)
+        ]
+        with m.If(self.source.stb & self.source.ack):
+            with m.If(last):
+                m.d.sync += mux.eq(0)
+            with m.Else():
+                m.d.sync += mux.eq(mux+1)
+
+        # data path
+        cases = {}
+        for i in range(ratio):
+            n = ratio-i-1 if reverse else i
+            cases[i] = self.source.data.eq(self.sink.data[n*nbits_to:(n+1)*nbits_to])
+        m.d.comb += Case(mux, cases).makedefault()
+
+        if report_valid_token_count:
+            m.d.comb += self.source.valid_token_count.eq(last)
+
+        return m
+
+
+class _IdentityConverter(Elaboratable):
+    def __init__(self, nbits_from, nbits_to, ratio, reverse,
+                 report_valid_token_count):
+        self.sink = Endpoint([("data", nbits_from)])
+        source_layout = [("data", nbits_to)]
+        if report_valid_token_count:
+            source_layout.append(("valid_token_count", 1))
+        self.source = Endpoint(source_layout)
+        assert ratio == 1
+        self.ratio = ratio
+        self._report_valid_token_count = report_valid_token_count
+
+    def elaborate(self, platform):
+        m = Module()
+
+        m.d.comb += self.sink.connect(self.source)
+        if self._report_valid_token_count:
+            m.d.comb += self.source.valid_token_count.eq(1)
+
+        return m
+
+
+def _get_converter_ratio(nbits_from, nbits_to):
+    if nbits_from > nbits_to:
+        specialized_cls = _DownConverter
+        if nbits_from % nbits_to:
+            raise ValueError("Ratio must be an int")
+        ratio = nbits_from//nbits_to
+    elif nbits_from < nbits_to:
+        specialized_cls = _UpConverter
+        if nbits_to % nbits_from:
+            raise ValueError("Ratio must be an int")
+        ratio = nbits_to//nbits_from
+    else:
+        specialized_cls = _IdentityConverter
+        ratio = 1
+
+    return specialized_cls, ratio
+
+
+class Converter(Elaboratable):
+    def __init__(self, nbits_from, nbits_to, reverse=False,
+                 report_valid_token_count=False):
+        cls, ratio = _get_converter_ratio(nbits_from, nbits_to)
+        self.specialized = cls(nbits_from, nbits_to, ratio,
+            reverse, report_valid_token_count)
+        self.sink = self.specialized.sink
+        self.source = self.specialized.source
+
+    def elaborate(self, platform):
+        m = Module()
+
+        m.submodules += self.specialized
+
+        return m
+
+
+class StrideConverter(Elaboratable):
+    def __init__(self, layout_from, layout_to, *args, **kwargs):
+        self.sink = sink = Endpoint(layout_from)
+        self.source = source = Endpoint(layout_to)
+
+        nbits_from = len(sink.payload.raw_bits())
+        nbits_to = len(source.payload.raw_bits())
+        self.converter = Converter(nbits_from, nbits_to, *args, **kwargs)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        nbits_from = len(sink.payload.raw_bits())
+        nbits_to = len(source.payload.raw_bits())
+        
+        m.d.submodules += self.converter
+
+        # cast sink to converter.sink (user fields --> raw bits)
+        m.d.comb += [
+            self.converter.sink.stb.eq(sink.stb),
+            self.converter.sink.eop.eq(sink.eop),
+            sink.ack.eq(self.converter.sink.ack)
+        ]
+        if isinstance(self.converter.specialized, _DownConverter):
+            ratio = self.converter.specialized.ratio
+            for i in range(ratio):
+                j = 0
+                for name, width in layout_to:
+                    src = getattr(sink, name)[i*width:(i+1)*width]
+                    dst = self.converter.sink.data[i*nbits_to+j:i*nbits_to+j+width]
+                    self.comb += dst.eq(src)
+                    j += width
+        else:
+            m.d.comb += self.converter.sink.data.eq(sink.payload.raw_bits())
+
+
+        # cast converter.source to source (raw bits --> user fields)
+        m.D.comb += [
+            source.stb.eq(self.converter.source.stb),
+            source.eop.eq(self.converter.source.eop),
+            self.converter.source.ack.eq(source.ack)
+        ]
+        if isinstance(self.converter.specialized, _UpConverter):
+            ratio = self.converter.specialized.ratio
+            for i in range(ratio):
+                j = 0
+                for name, width in layout_from:
+                    src = self.converter.source.data[i*nbits_from+j:i*nbits_from+j+width]
+                    dst = getattr(source, name)[i*width:(i+1)*width]
+                    m.d.comb += dst.eq(src)
+                    j += width
+        else:
+            m.d.comb += source.payload.raw_bits().eq(self.converter.source.data)
+
+        return m