Finish porting stream classes to nMigen
authorJean THOMAS <git0@pub.jeanthomas.me>
Wed, 10 Jun 2020 08:27:55 +0000 (10:27 +0200)
committerJean THOMAS <git0@pub.jeanthomas.me>
Wed, 10 Jun 2020 08:27:55 +0000 (10:27 +0200)
gram/stream.py

index 9f30c1ee7a3093907b589c8592955a321a594ba9..c97e5aab38708c73d6166fa4dc690d5b9ad45e6a 100644 (file)
@@ -154,23 +154,25 @@ class _UpConverter(Elaboratable):
         self.source = source = Endpoint(source_layout)
         self.ratio = ratio
         self._nbits_from  = nbits_from
         self.source = source = Endpoint(source_layout)
         self.ratio = ratio
         self._nbits_from  = nbits_from
+        self._reverse = reverse
+        self._report_valid_token_count = report_valid_token_count
 
     def elaborate(self, platform):
         m = Module()
 
         # control path
 
     def elaborate(self, platform):
         m = Module()
 
         # control path
-        demux = Signal(max=ratio)
+        demux = Signal(range(self.ratio))
         load_part = Signal()
         strobe_all = Signal()
         m.d.comb += [
         load_part = Signal()
         strobe_all = Signal()
         m.d.comb += [
-            self.sink.ack.eq(~strobe_all | self.source.ack),
-            self.source.stb.eq(strobe_all),
-            load_part.eq(self.sink.stb & self.sink.ack)
+            self.sink.ready.eq(~strobe_all | self.source.ready),
+            self.source.valid.eq(strobe_all),
+            load_part.eq(self.sink.valid & self.sink.ready)
         ]
 
         ]
 
-        demux_last = ((demux == (ratio - 1)) | self.sink.eop)
+        demux_last = ((demux == (self.ratio - 1)) | self.sink.last)
 
 
-        with m.If(source.ack):
+        with m.If(self.source.ready):
             m.d.sync += strobe_all.eq(0)
 
         with m.If(load_part):
             m.d.sync += strobe_all.eq(0)
 
         with m.If(load_part):
@@ -182,19 +184,20 @@ class _UpConverter(Elaboratable):
             with m.Else():
                 m.d.sync += demux.eq(demux+1)
 
             with m.Else():
                 m.d.sync += demux.eq(demux+1)
 
-        with m.If(self.source.stb & self.source.ack):
-            m.d.sync += self.source.eop.eq(self.sink.eop)
-        with m.Elif(self.sink.stb & self.sink.ack):
-            m.d.sync += self.source.eop.eq(self.sink.eop | self.source.eop)
+        with m.If(self.source.valid & self.source.ready):
+            m.d.sync += self.source.last.eq(self.sink.last)
+        with m.Elif(self.sink.valid & self.sink.ready):
+            m.d.sync += self.source.last.eq(self.sink.last | self.source.last)
 
         # data path
 
         # data path
-        cases = {}
-        for i in range(ratio):
-            n = ratio-i-1 if reverse else i
-            cases[i] = source.data[n*self._nbits_from:(n+1)*self._nbits_from].eq(sink.data)
-        m.d.sync += If(load_part, Case(demux, cases))
+        with m.If(load_part):
+            with m.Switch(demux):
+                for i in range(self.ratio):
+                    with m.Case(i):
+                        n = self.ratio-i-1 if self._reverse else i
+                        m.d.sync += self.source.payload.lower()[n*self._nbits_from:(n+1)*self._nbits_from].eq(self.sink.payload)
 
 
-        if report_valid_token_count:
+        if self._report_valid_token_count:
             with m.If(load_part):
                 m.d.sync += self.source.valid_token_count.eq(demux + 1)
 
             with m.If(load_part):
                 m.d.sync += self.source.valid_token_count.eq(demux + 1)
 
@@ -204,12 +207,15 @@ class _UpConverter(Elaboratable):
 class _DownConverter(Elaboratable):
     def __init__(self, nbits_from, nbits_to, ratio, reverse,
                  report_valid_token_count):
 class _DownConverter(Elaboratable):
     def __init__(self, nbits_from, nbits_to, ratio, reverse,
                  report_valid_token_count):
-        self.sink = sink = Endpoint([("data", nbits_from)])
+        self.sink = Endpoint([("data", nbits_from)])
         source_layout = [("data", nbits_to)]
         if report_valid_token_count:
             source_layout.append(("valid_token_count", 1))
         source_layout = [("data", nbits_to)]
         if report_valid_token_count:
             source_layout.append(("valid_token_count", 1))
-        self.source = source = Endpoint(source_layout)
+        self.source = Endpoint(source_layout)
         self.ratio = ratio
         self.ratio = ratio
+        self._reverse = reverse
+        self._nbits_to = nbits_to
+        self._report_valid_token_count = report_valid_token_count
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -219,24 +225,30 @@ class _DownConverter(Elaboratable):
         last = Signal()
         m.d.comb += [
             last.eq(mux == (self.ratio-1)),
         last = Signal()
         m.d.comb += [
             last.eq(mux == (self.ratio-1)),
-            self.source.stb.eq(self.sink.stb),
-            self.source.eop.eq(sink.eop & last),
-            self.sink.ack.eq(last & self.source.ack)
+            self.source.valid.eq(self.sink.valid),
+            self.source.last.eq(self.sink.last & last),
+            self.sink.ready.eq(last & self.source.ready)
         ]
         ]
-        with m.If(self.source.stb & self.source.ack):
+        with m.If(self.source.valid & self.source.ready):
             with m.If(last):
                 m.d.sync += mux.eq(0)
             with m.Else():
                 m.d.sync += mux.eq(mux+1)
 
         # data path
             with m.If(last):
                 m.d.sync += mux.eq(0)
             with m.Else():
                 m.d.sync += mux.eq(mux+1)
 
         # data path
-        cases = {}
-        for i in range(ratio):
-            n = ratio-i-1 if reverse else i
-            cases[i] = self.source.data.eq(self.sink.data[n*nbits_to:(n+1)*nbits_to])
-        m.d.comb += Case(mux, cases).makedefault()
+        # cases = {}
+        # for i in range(self.ratio):
+        #     n = self.ratio-i-1 if self._reverse else i
+        #     cases[i] = self.source.data.eq(self.sink.data[n*self._nbits_to:(n+1)*self._nbits_to])
+        # m.d.comb += Case(mux, cases).makedefault()
+
+        with m.Switch(mux):
+            for i in range(self.ratio):
+                with m.Case(i):
+                    n = self.ratio-i-1 if self._reverse else i
+                    m.d.comb += self.source.data.eq(self.sink.data[n*self._nbits_to:(n+1)*self._nbits_to])
 
 
-        if report_valid_token_count:
+        if self._report_valid_token_count:
             m.d.comb += self.source.valid_token_count.eq(last)
 
         return m
             m.d.comb += self.source.valid_token_count.eq(last)
 
         return m
@@ -286,7 +298,7 @@ class Converter(Elaboratable):
     def __init__(self, nbits_from, nbits_to, reverse=False,
                  report_valid_token_count=False):
         cls, ratio = _get_converter_ratio(nbits_from, nbits_to)
     def __init__(self, nbits_from, nbits_to, reverse=False,
                  report_valid_token_count=False):
         cls, ratio = _get_converter_ratio(nbits_from, nbits_to)
-        self.specialized,_ = cls(nbits_from, nbits_to, ratio,
+        self.specialized = cls(nbits_from, nbits_to, ratio,
             reverse, report_valid_token_count)
         self.sink = self.specialized.sink
         self.source = self.specialized.source
             reverse, report_valid_token_count)
         self.sink = self.specialized.sink
         self.source = self.specialized.source
@@ -304,6 +316,9 @@ class StrideConverter(Elaboratable):
         self.sink = sink = Endpoint(layout_from)
         self.source = source = Endpoint(layout_to)
 
         self.sink = sink = Endpoint(layout_from)
         self.source = source = Endpoint(layout_to)
 
+        self._layout_to = layout_to
+        self._layout_from = layout_from
+
         nbits_from = len(sink.payload.lower())
         nbits_to = len(source.payload.lower())
         self.converter = Converter(nbits_from, nbits_to, *args, **kwargs)
         nbits_from = len(sink.payload.lower())
         nbits_to = len(source.payload.lower())
         self.converter = Converter(nbits_from, nbits_to, *args, **kwargs)
@@ -314,44 +329,44 @@ class StrideConverter(Elaboratable):
         nbits_from = len(self.sink.payload.lower())
         nbits_to = len(self.source.payload.lower())
         
         nbits_from = len(self.sink.payload.lower())
         nbits_to = len(self.source.payload.lower())
         
-        m.d.submodules += self.converter
+        m.submodules += self.converter
 
         # cast sink to converter.sink (user fields --> raw bits)
         m.d.comb += [
 
         # cast sink to converter.sink (user fields --> raw bits)
         m.d.comb += [
-            self.converter.sink.stb.eq(self.sink.stb),
-            self.converter.sink.eop.eq(self.sink.eop),
-            sink.ack.eq(self.converter.sink.ack)
+            self.converter.sink.valid.eq(self.sink.valid),
+            self.converter.sink.last.eq(self.sink.last),
+            self.sink.ready.eq(self.converter.sink.ready)
         ]
         if isinstance(self.converter.specialized, _DownConverter):
             ratio = self.converter.specialized.ratio
             for i in range(ratio):
                 j = 0
         ]
         if isinstance(self.converter.specialized, _DownConverter):
             ratio = self.converter.specialized.ratio
             for i in range(ratio):
                 j = 0
-                for name, width in layout_to:
+                for name, width in self._layout_to.payload_layout:
                     src = getattr(self.sink, name)[i*width:(i+1)*width]
                     dst = self.converter.sink.data[i*nbits_to+j:i*nbits_to+j+width]
                     m.d.comb += dst.eq(src)
                     j += width
         else:
                     src = getattr(self.sink, name)[i*width:(i+1)*width]
                     dst = self.converter.sink.data[i*nbits_to+j:i*nbits_to+j+width]
                     m.d.comb += dst.eq(src)
                     j += width
         else:
-            m.d.comb += self.converter.sink.data.eq(self.sink.payload.lower())
+            m.d.comb += self.converter.sink.payload.eq(self.sink.payload.lower())
 
 
         # cast converter.source to source (raw bits --> user fields)
         m.d.comb += [
 
 
         # cast converter.source to source (raw bits --> user fields)
         m.d.comb += [
-            self.source.stb.eq(self.converter.source.stb),
-            self.source.eop.eq(self.converter.source.eop),
-            self.converter.source.ack.eq(self.source.ack)
+            self.source.valid.eq(self.converter.source.valid),
+            self.source.last.eq(self.converter.source.last),
+            self.converter.source.ready.eq(self.source.ready)
         ]
         if isinstance(self.converter.specialized, _UpConverter):
             ratio = self.converter.specialized.ratio
             for i in range(ratio):
                 j = 0
         ]
         if isinstance(self.converter.specialized, _UpConverter):
             ratio = self.converter.specialized.ratio
             for i in range(ratio):
                 j = 0
-                for name, width in layout_from:
+                for name, width in self._layout_from.payload_layout:
                     src = self.converter.source.data[i*nbits_from+j:i*nbits_from+j+width]
                     dst = getattr(self.source, name)[i*width:(i+1)*width]
                     m.d.comb += dst.eq(src)
                     j += width
         else:
                     src = self.converter.source.data[i*nbits_from+j:i*nbits_from+j+width]
                     dst = getattr(self.source, name)[i*width:(i+1)*width]
                     m.d.comb += dst.eq(src)
                     j += width
         else:
-            m.d.comb += self.source.payload.lower().eq(self.converter.source.data)
+            m.d.comb += self.source.payload.lower().eq(self.converter.source.payload)
 
         return m
 
 
         return m
 
@@ -372,8 +387,8 @@ class Pipeline(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
 
-        n = len(modules)
-        mod = modules[0]
+        n = len(self._modules)
+        mod = self._modules[0]
         
         for i in range(1, n):
             mod_n = self._modules[i]
         
         for i in range(1, n):
             mod_n = self._modules[i]
@@ -381,7 +396,7 @@ class Pipeline(Elaboratable):
                 source = mod
             else:
                 source = mod.source
                 source = mod
             else:
                 source = mod.source
-            if isinstance(m_n, Endpoint):
+            if isinstance(mod_n, Endpoint):
                 sink = mod_n
             else:
                 sink = mod_n.sink
                 sink = mod_n
             else:
                 sink = mod_n.sink