soc/interconnect/stream: use new Converter/StrideConverter
authorFlorent Kermarrec <florent@enjoy-digital.fr>
Wed, 16 Mar 2016 16:00:58 +0000 (17:00 +0100)
committerFlorent Kermarrec <florent@enjoy-digital.fr>
Wed, 16 Mar 2016 16:00:58 +0000 (17:00 +0100)
litex/soc/interconnect/stream.py

index d69e5a664a8f83d216ca3daec5a30bd1fee6f3e9..f9a3bf33578a542165f423767a175c33424c8c72 100644 (file)
@@ -29,11 +29,11 @@ class EndpointDescription:
             attributed.add(f[0])
 
         full_layout = [
-            ("payload", _make_m2s(self.payload_layout)),
-            ("param", _make_m2s(self.param_layout)),
             ("stb", 1, DIR_M_TO_S),
             ("ack", 1, DIR_S_TO_M),
-            ("eop", 1, DIR_M_TO_S)
+            ("eop", 1, DIR_M_TO_S),
+            ("payload", _make_m2s(self.payload_layout)),
+            ("param", _make_m2s(self.param_layout))
         ]
         return full_layout
 
@@ -62,9 +62,8 @@ class Sink(Endpoint):
 
 class _FIFOWrapper(Module):
     def __init__(self, fifo_class, layout, depth):
-        self.sink = Sink(layout)
-        self.source = Source(layout)
-        self.busy = Signal()
+        self.sink = Endpoint(layout)
+        self.source = Endpoint(layout)
 
         # # #
 
@@ -112,10 +111,10 @@ class AsyncFIFO(_FIFOWrapper):
 
 class Multiplexer(Module):
     def __init__(self, layout, n):
-        self.source = Source(layout)
+        self.source = Endpoint(layout)
         sinks = []
         for i in range(n):
-            sink = Sink(layout)
+            sink = Endpoint(layout)
             setattr(self, "sink"+str(i), sink)
             sinks.append(sink)
         self.sel = Signal(max=n)
@@ -130,10 +129,10 @@ class Multiplexer(Module):
 
 class Demultiplexer(Module):
     def __init__(self, layout, n):
-        self.sink = Sink(layout)
+        self.sink = Endpoint(layout)
         sources = []
         for i in range(n):
-            source = Source(layout)
+            source = Endpoint(layout)
             setattr(self, "source"+str(i), source)
             sources.append(source)
         self.sel = Signal(max=n)
@@ -145,6 +144,206 @@ class Demultiplexer(Module):
             cases[i] = self.sink.connect(source)
         self.comb += Case(self.sel, cases)
 
+
+class _UpConverter(Module):
+    def __init__(self, nbits_from, nbits_to, ratio, reverse):
+        self.sink = sink = Endpoint([("data", nbits_from)])
+        self.source = source = Endpoint([("data", nbits_to),
+                                         ("valid_token_count", bits_for(ratio))])
+        self.latency = 1
+
+        # # #
+
+        # control path
+        demux = Signal(max=ratio)
+        load_part = Signal()
+        strobe_all = Signal()
+        self.comb += [
+            sink.ack.eq(~strobe_all | source.ack),
+            source.stb.eq(strobe_all),
+            load_part.eq(sink.stb & sink.ack)
+        ]
+
+        demux_last = ((demux == (ratio - 1)) | sink.eop)
+
+        self.sync += [
+            If(source.ack, strobe_all.eq(0)),
+            If(load_part,
+                If(demux_last,
+                    demux.eq(0),
+                    strobe_all.eq(1)
+                ).Else(
+                    demux.eq(demux + 1)
+                )
+            ),
+            If(source.stb & source.ack,
+                source.eop.eq(sink.eop),
+            ).Elif(sink.stb & sink.ack,
+                source.eop.eq(sink.eop | source.eop)
+            )
+        ]
+
+        # data path
+        cases = {}
+        for i in range(ratio):
+            n = ratio-i-1 if reverse else i
+            cases[i] = source.data[n*nbits_from:(n+1)*nbits_from].eq(sink.data)
+        self.sync += If(load_part, Case(demux, cases))
+
+        # valid token count
+        self.sync += If(load_part, source.valid_token_count.eq(demux + 1))
+
+
+class _DownConverter(Module):
+    def __init__(self, nbits_from, nbits_to, ratio, reverse):
+        self.sink = sink = Endpoint([("data", nbits_from)])
+        self.source = source = Endpoint([("data", nbits_to),
+                                         ("valid_token_count", 1)])
+        self.latency = 0
+
+        # # #
+
+        # control path
+        mux = Signal(max=ratio)
+        last = Signal()
+        self.comb += [
+            last.eq(mux == (ratio-1)),
+            source.stb.eq(sink.stb),
+            source.eop.eq(sink.eop & last),
+            sink.ack.eq(last & source.ack)
+        ]
+        self.sync += \
+            If(source.stb & source.ack,
+                If(last,
+                    mux.eq(0)
+                ).Else(
+                    mux.eq(mux + 1)
+                )
+            )
+
+        # data path
+        cases = {}
+        for i in range(ratio):
+            n = ratio-i-1 if reverse else i
+            cases[i] = source.data.eq(sink.data[n*nbits_to:(n+1)*nbits_to])
+        self.comb += Case(mux, cases).makedefault()
+
+        # valid token count
+        self.comb += source.valid_token_count.eq(last)
+
+
+class _IdentityConverter(Module):
+    def __init__(self, nbits_from, nbits_to, ratio, reverse):
+        self.sink = sink = Endpoint([("data", nbits_from)])
+        self.source = source = Endpoint([("data", nbits_to),
+                                         ("valid_token_count", 1)])
+        self.latency = 0
+
+        # # #
+
+        self.comb += [
+            sink.connect(source),
+            source.valid_token_count.eq(1)
+        ]
+
+
+def _get_converter_ratio(nbits_from, nbits_to):
+    if nbits_from > nbits_to:
+        converter_cls = _DownConverter
+        if nbits_from % nbits_to:
+            raise ValueError("Ratio must be an int")
+        ratio = nbits_from//nbits_to
+    elif nbits_from < nbits_to:
+        converter_cls = _UpConverter
+        if nbits_to % nbits_from:
+            raise ValueError("Ratio must be an int")
+        ratio = nbits_to//nbits_from
+    else:
+        converter_cls = _IdentityConverter
+        ratio = 1
+
+    return converter_cls, ratio
+
+
+class Converter(Module):
+    def __init__(self, nbits_from, nbits_to, reverse=False,
+        report_valid_token_count=False):
+        self.cls, self.ratio = _get_converter_ratio(nbits_from, nbits_to)
+
+        # # #
+
+        converter = self.cls(nbits_from, nbits_to, self.ratio, reverse)
+        self.submodules += converter
+        self.latency = converter.latency
+
+        self.sink = converter.sink
+        if report_valid_token_count:
+            self.source = converter.source
+        else:
+            self.source = Endpoint([("data", nbits_to)])
+            self.comb += converter.source.connect(self.source,
+                            leave_out=set(["valid_token_count"]))
+
+
+class StrideConverter(Module):
+    def __init__(self, description_from, description_to, reverse=False):
+        self.sink = sink = Endpoint(description_from)
+        self.source = source = Endpoint(description_to)
+
+        # # #
+
+        nbits_from = len(sink.payload.raw_bits())
+        nbits_to = len(source.payload.raw_bits())
+
+        converter = Converter(nbits_from, nbits_to, reverse)
+        self.submodules += converter
+
+        # cast sink to converter.sink (user fields --> raw bits)
+        self.comb += [
+            converter.sink.stb.eq(sink.stb),
+            converter.sink.eop.eq(sink.eop),
+            sink.ack.eq(converter.sink.ack)
+        ]
+        if converter.cls == _DownConverter:
+            ratio = converter.ratio
+            for i in range(ratio):
+                j = 0
+                for name, width in source.description.payload_layout:
+                    src = getattr(sink, name)[i*width:(i+1)*width]
+                    dst = converter.sink.data[i*nbits_to+j:i*nbits_to+j+width]
+                    self.comb += dst.eq(src)
+                    j += width
+        else:
+            self.comb += converter.sink.data.eq(sink.payload.raw_bits())
+
+
+        # cast converter.source to source (raw bits --> user fields)
+        self.comb += [
+            source.stb.eq(converter.source.stb),
+            source.eop.eq(converter.source.eop),
+            converter.source.ack.eq(source.ack)
+        ]
+        if converter.cls == _UpConverter:
+            ratio = converter.ratio
+            for i in range(ratio):
+                j = 0
+                for name, width in sink.description.payload_layout:
+                    src = converter.source.data[i*nbits_from+j:i*nbits_from+j+width]
+                    dst = getattr(source, name)[i*width:(i+1)*width]
+                    self.comb += dst.eq(src)
+                    j += width
+        else:
+            self.comb += source.payload.raw_bits().eq(converter.source.data)
+
+        # connect params
+        if converter.latency == 0:
+            self.comb += source.param.eq(sink.param)
+        elif converter.latency == 1:
+            self.sync += source.param.eq(sink.param)
+        else:
+            raise ValueError
+
+
 # TODO: clean up code below
 # XXX
 
@@ -351,99 +550,6 @@ class Pack(Module):
         ]
 
 
-class Chunkerize(CombinatorialActor):
-    def __init__(self, layout_from, layout_to, n, reverse=False):
-        self.sink = Sink(layout_from)
-        if isinstance(layout_to, EndpointDescription):
-            layout_to = copy(layout_to)
-            layout_to.payload_layout = pack_layout(layout_to.payload_layout, n)
-        else:
-            layout_to = pack_layout(layout_to, n)
-        self.source = Source(layout_to)
-        CombinatorialActor.__init__(self)
-
-        # # #
-
-        for i in range(n):
-            chunk = n-i-1 if reverse else i
-            for f in self.sink.description.payload_layout:
-                src = getattr(self.sink, f[0])
-                dst = getattr(getattr(self.source, "chunk"+str(chunk)), f[0])
-                self.comb += dst.eq(src[i*len(src)//n:(i+1)*len(src)//n])
-
-        for f in self.sink.description.param_layout:
-            src = getattr(self.sink, f[0])
-            dst = getattr(self.source, f[0])
-            self.comb += dst.eq(src)
-
-
-class Unchunkerize(CombinatorialActor):
-    def __init__(self, layout_from, n, layout_to, reverse=False):
-        if isinstance(layout_from, EndpointDescription):
-            fields = layout_from.payload_layout
-            layout_from = copy(layout_from)
-            layout_from.payload_layout = pack_layout(layout_from.payload_layout, n)
-        else:
-            fields = layout_from
-            layout_from = pack_layout(layout_from, n)
-        self.sink = Sink(layout_from)
-        self.source = Source(layout_to)
-        CombinatorialActor.__init__(self)
-
-        # # #
-
-        for i in range(n):
-            chunk = n-i-1 if reverse else i
-            for f in fields:
-                src = getattr(getattr(self.sink, "chunk"+str(chunk)), f[0])
-                dst = getattr(self.source, f[0])
-                self.comb += dst[i*len(dst)//n:(i+1)*len(dst)//n].eq(src)
-
-        for f in self.sink.description.param_layout:
-            src = getattr(self.sink, f[0])
-            dst = getattr(self.source, f[0])
-            self.comb += dst.eq(src)
-
-
-class Converter(Module):
-    def __init__(self, layout_from, layout_to, reverse=False):
-        self.sink = Sink(layout_from)
-        self.source = Source(layout_to)
-
-        # # #
-
-        width_from = len(self.sink.payload.raw_bits())
-        width_to = len(self.source.payload.raw_bits())
-
-        # downconverter
-        if width_from > width_to:
-            if width_from % width_to:
-                raise ValueError
-            ratio = width_from//width_to
-            self.submodules.chunkerize = Chunkerize(layout_from, layout_to, ratio, reverse)
-            self.submodules.unpack = Unpack(ratio, layout_to)
-
-            self.submodules += Pipeline(self.sink,
-                                        self.chunkerize,
-                                        self.unpack,
-                                        self.source)
-        # upconverter
-        elif width_to > width_from:
-            if width_to % width_from:
-                raise ValueError
-            ratio = width_to//width_from
-            self.submodules.pack = Pack(layout_from, ratio)
-            self.submodules.unchunkerize = Unchunkerize(layout_from, ratio, layout_to, reverse)
-
-            self.submodules += Pipeline(self.sink,
-                                        self.pack,
-                                        self.unchunkerize,
-                                        self.source)
-        # direct connection
-        else:
-            self.comb += self.sink.connect(self.source)
-
-
 class Pipeline(Module):
     def __init__(self, *modules):
         n = len(modules)