Remove useless variables in _Steerer, ensure command array has 4 elements
[gram.git] / gram / core / multiplexer.py
index 3a9203d830b28a02ebbe3af3552bee3937211388..b813b807f8c7ba2e7cc6d96c90650e8afc8e1924 100644 (file)
@@ -12,14 +12,11 @@ from operator import or_, and_
 
 from nmigen import *
 
-from lambdasoc.periph import Peripheral
-
 from gram.common import *
-from gram.core.bandwidth import Bandwidth
 import gram.stream as stream
 from gram.compat import RoundRobin, delayed_enter
 
-# _CommandChooser ----------------------------------------------------------------------------------
+__ALL__ = ["Multiplexer"]
 
 class _CommandChooser(Elaboratable):
     """Arbitrates between requests, filtering them based on their type
@@ -45,18 +42,20 @@ class _CommandChooser(Elaboratable):
     cmd : Endpoint(cmd_request_rw_layout)
         Currently selected request stream (when ~cmd.valid, cas/ras/we are 0)
     """
+
     def __init__(self, requests):
-        self.want_reads     = Signal()
-        self.want_writes    = Signal()
-        self.want_cmds      = Signal()
+        self.want_reads = Signal()
+        self.want_writes = Signal()
+        self.want_cmds = Signal()
         self.want_activates = Signal()
 
         self._requests = requests
-        a  = len(requests[0].a)
+        a = len(requests[0].a)
         ba = len(requests[0].ba)
 
         # cas/ras/we are 0 when valid is inactive
         self.cmd = stream.Endpoint(cmd_request_rw_layout(a, ba))
+        self.ready = Signal(len(requests))
 
     def elaborate(self, platform):
         m = Module()
@@ -69,11 +68,10 @@ class _CommandChooser(Elaboratable):
             command = request.is_cmd & self.want_cmds & (~is_act_cmd | self.want_activates)
             read = request.is_read == self.want_reads
             write = request.is_write == self.want_writes
-            self.comb += valids[i].eq(request.valid & (command | (read & write)))
-
+            m.d.comb += valids[i].eq(request.valid & (command | (read & write)))
 
-        arbiter = RoundRobin(n, SP_CE)
-        self.submodules += arbiter
+        arbiter = RoundRobin(n)
+        m.submodules += arbiter
         choices = Array(valids[i] for i in range(n))
         m.d.comb += [
             arbiter.request.eq(valids),
@@ -82,21 +80,21 @@ class _CommandChooser(Elaboratable):
 
         for name in ["a", "ba", "is_read", "is_write", "is_cmd"]:
             choices = Array(getattr(req, name) for req in self._requests)
-            self.comb += getattr(self.cmd, name).eq(choices[arbiter.grant])
+            m.d.comb += getattr(self.cmd, name).eq(choices[arbiter.grant])
 
         for name in ["cas", "ras", "we"]:
             # we should only assert those signals when valid is 1
             choices = Array(getattr(req, name) for req in self._requests)
             with m.If(self.cmd.valid):
-                m.d.comb += getattr(cmd, name).eq(choices[arbiter.grant])
+                m.d.comb += getattr(self.cmd, name).eq(choices[arbiter.grant])
 
         for i, request in enumerate(self._requests):
             with m.If(self.cmd.valid & self.cmd.ready & (arbiter.grant == i)):
-                m.d.comb += request.ready.eq(1)
+                m.d.comb += self.ready[i].eq(1)
 
         # Arbitrate if a command is being accepted or if the command is not valid to ensure a valid
         # command is selected when cmd.ready goes high.
-        m.d.comb += arbiter.ce.eq(self.cmd.ready | ~self.cmd.valid)
+        m.d.comb += arbiter.stb.eq(self.cmd.ready | ~self.cmd.valid)
 
         return m
 
@@ -115,8 +113,10 @@ class _CommandChooser(Elaboratable):
 
 # _Steerer -----------------------------------------------------------------------------------------
 
+
 (STEER_NOP, STEER_CMD, STEER_REQ, STEER_REFRESH) = range(4)
 
+
 class _Steerer(Elaboratable):
     """Connects selected request to DFI interface
 
@@ -141,12 +141,12 @@ class _Steerer(Elaboratable):
         DFI phase. The signals should take one of the values from STEER_* to
         select given source.
     """
+
     def __init__(self, commands, dfi):
+        assert len(commands) == 4
         self._commands = commands
         self._dfi = dfi
-        ncmd = len(commands)
-        nph  = len(dfi.phases)
-        self.sel = [Signal(range(ncmd)) for i in range(nph)]
+        self.sel = [Signal(range(len(commands))) for i in range(len(dfi.phases))]
 
     def elaborate(self, platform):
         m = Module()
@@ -161,26 +161,28 @@ class _Steerer(Elaboratable):
                 return cmd.valid & cmd.ready & getattr(cmd, attr)
 
         for i, (phase, sel) in enumerate(zip(dfi.phases, self.sel)):
-            nranks   = len(phase.cs_n)
+            nranks = len(phase.cs_n)
             rankbits = log2_int(nranks)
             if hasattr(phase, "reset_n"):
-                self.comb += phase.reset_n.eq(1)
-            m.d.comb += phase.cke.eq(Replicate(Signal(reset=1), nranks))
+                m.d.comb += phase.reset_n.eq(1)
+            m.d.comb += phase.clk_en.eq(Repl(Signal(reset=1), nranks))
             if hasattr(phase, "odt"):
                 # FIXME: add dynamic drive for multi-rank (will be needed for high frequencies)
-                m.d.comb += phase.odt.eq(Replicate(Signal(reset=1), nranks))
+                m.d.comb += phase.odt.eq(Repl(Signal(reset=1), nranks))
             if rankbits:
                 rank_decoder = Decoder(nranks)
-                self.submodules += rank_decoder
-                m.d.comb += rank_decoder.i.eq((Array(cmd.ba[-rankbits:] for cmd in commands)[sel]))
-                if i == 0: # Select all ranks on refresh.
+                m.submodules += rank_decoder
+                m.d.comb += rank_decoder.i.eq(
+                    (Array(cmd.ba[-rankbits:] for cmd in commands)[sel]))
+                if i == 0:  # Select all ranks on refresh.
                     with m.If(sel == STEER_REFRESH):
                         m.d.sync += phase.cs_n.eq(0)
                     with m.Else():
                         m.d.sync += phase.cs_n.eq(~rank_decoder.o)
                 else:
                     m.d.sync += phase.cs_n.eq(~rank_decoder.o)
-                m.d.sync += phase.bank.eq(Array(cmd.ba[:-rankbits] for cmd in commands)[sel])
+                m.d.sync += phase.bank.eq(Array(cmd.ba[:-rankbits]
+                                                for cmd in commands)[sel])
             else:
                 m.d.sync += [
                     phase.cs_n.eq(0),
@@ -189,9 +191,12 @@ class _Steerer(Elaboratable):
 
             m.d.sync += [
                 phase.address.eq(Array(cmd.a for cmd in commands)[sel]),
-                phase.cas_n.eq(~Array(valid_and(cmd, "cas") for cmd in commands)[sel]),
-                phase.ras_n.eq(~Array(valid_and(cmd, "ras") for cmd in commands)[sel]),
-                phase.we_n.eq(~Array(valid_and(cmd, "we") for cmd in commands)[sel])
+                phase.cas_n.eq(~Array(valid_and(cmd, "cas")
+                                      for cmd in commands)[sel]),
+                phase.ras_n.eq(~Array(valid_and(cmd, "ras")
+                                      for cmd in commands)[sel]),
+                phase.we_n.eq(~Array(valid_and(cmd, "we")
+                                     for cmd in commands)[sel])
             ]
 
             rddata_ens = Array(valid_and(cmd, "is_read") for cmd in commands)
@@ -203,9 +208,28 @@ class _Steerer(Elaboratable):
 
         return m
 
-# Multiplexer --------------------------------------------------------------------------------------
+class _AntiStarvation(Elaboratable):
+    def __init__(self, timeout):
+        self.en = Signal()
+        self.max_time = Signal()
+
+    def elaborate(self, platform):
+        m = Module()
+
+        if timeout > 0:
+            t = timeout - 1
+            time = Signal(range(t+1))
+            m.d.comb += max_time.eq(time == 0)
+            with m.If(~en):
+                m.d.sync += time.eq(t)
+            with m.Elif(~max_time):
+                m.d.sync += time.eq(time - 1)
+        else:
+            m.d.comb += max_time.eq(0)
+
+        return m
 
-class Multiplexer(Peripheral, Elaboratable):
+class Multiplexer(Elaboratable):
     """Multplexes requets from BankMachines to DFI
 
     This module multiplexes requests from BankMachines (and Refresher) and
@@ -226,13 +250,28 @@ class Multiplexer(Peripheral, Elaboratable):
     interface : LiteDRAMInterface
         Data interface connected directly to LiteDRAMCrossbar
     """
+
     def __init__(self,
-            settings,
-            bank_machines,
-            refresher,
-            dfi,
-            interface):
+                 settings,
+                 bank_machines,
+                 refresher,
+                 dfi,
+                 interface):
         assert(settings.phy.nphases == len(dfi.phases))
+        self._settings = settings
+        self._bank_machines = bank_machines
+        self._refresher = refresher
+        self._dfi = dfi
+        self._interface = interface
+
+    def elaborate(self, platform):
+        m = Module()
+
+        settings = self._settings
+        bank_machines = self._bank_machines
+        refresher = self._refresher
+        dfi = self._dfi
+        interface = self._interface
 
         ras_allowed = Signal(reset=1)
         cas_allowed = Signal(reset=1)
@@ -241,6 +280,9 @@ class Multiplexer(Peripheral, Elaboratable):
         requests = [bm.cmd for bm in bank_machines]
         m.submodules.choose_cmd = choose_cmd = _CommandChooser(requests)
         m.submodules.choose_req = choose_req = _CommandChooser(requests)
+        for i, request in enumerate(requests):
+            m.d.comb += request.ready.eq(
+                choose_cmd.ready[i] | choose_req.ready[i])
         if settings.phy.nphases == 1:
             # When only 1 phase, use choose_req for all requests
             choose_cmd = choose_req
@@ -292,28 +334,17 @@ class Multiplexer(Peripheral, Elaboratable):
         ]
 
         # Anti Starvation --------------------------------------------------------------------------
+        m.submodules.read_antistarvation = read_antistarvation = _AntiStarvation(settings.read_time)
+        read_time_en = read_antistarvation.en
+        max_read_time = read_antistarvation.max_time
 
-        def anti_starvation(timeout):
-            en = Signal()
-            max_time = Signal()
-            if timeout:
-                t = timeout - 1
-                time = Signal(range(t+1))
-                m.d.comb += max_time.eq(time == 0)
-                m.d.sync += If(~en,
-                        time.eq(t)
-                    ).Elif(~max_time,
-                        time.eq(time - 1)
-                    )
-            else:
-                m.d.comb += max_time.eq(0)
-            return en, max_time
-
-        read_time_en,   max_read_time = anti_starvation(settings.read_time)
-        write_time_en, max_write_time = anti_starvation(settings.write_time)
+        m.submodules.write_antistarvation = write_antistarvation = _AntiStarvation(settings.write_time)
+        write_time_en = write_antistarvation.en
+        max_write_time = write_antistarvation.max_time
 
         # Refresh ----------------------------------------------------------------------------------
-        m.d.comb += [bm.refresh_req.eq(refresher.cmd.valid) for bm in bank_machines]
+        m.d.comb += [bm.refresh_req.eq(refresher.cmd.valid)
+                     for bm in bank_machines]
         go_to_refresh = Signal()
         bm_refresh_gnts = [bm.refresh_gnt for bm in bank_machines]
         m.d.comb += go_to_refresh.eq(reduce(and_, bm_refresh_gnts))
@@ -357,11 +388,13 @@ class Multiplexer(Peripheral, Elaboratable):
                 ]
 
                 with m.If(settings.phy.nphases == 1):
-                    m.d.comb += choose_req.cmd.ready.eq(cas_allowed & (~choose_req.activate() | ras_allowed))
+                    m.d.comb += choose_req.cmd.ready.eq(
+                        cas_allowed & (~choose_req.activate() | ras_allowed))
                 with m.Else():
                     m.d.comb += [
                         choose_cmd.want_activates.eq(ras_allowed),
-                        choose_cmd.cmd.ready.eq(~choose_cmd.activate() | ras_allowed),
+                        choose_cmd.cmd.ready.eq(
+                            ~choose_cmd.activate() | ras_allowed),
                         choose_req.cmd.ready.eq(cas_allowed),
                     ]
 
@@ -381,11 +414,13 @@ class Multiplexer(Peripheral, Elaboratable):
                 ]
 
                 with m.If(settings.phy.nphases == 1):
-                    m.d.comb += choose_req.cmd.ready.eq(cas_allowed & (~choose_req.activate() | ras_allowed))
+                    m.d.comb += choose_req.cmd.ready.eq(
+                        cas_allowed & (~choose_req.activate() | ras_allowed))
                 with m.Else():
                     m.d.comb += [
                         choose_cmd.want_activates.eq(ras_allowed),
-                        choose_cmd.cmd.ready.eq(~choose_cmd.activate() | ras_allowed),
+                        choose_cmd.cmd.ready.eq(
+                            ~choose_cmd.activate() | ras_allowed),
                         choose_req.cmd.ready.eq(cas_allowed),
                     ]
 
@@ -407,10 +442,8 @@ class Multiplexer(Peripheral, Elaboratable):
             with m.State("WTR"):
                 with m.If(twtrcon.ready):
                     m.next = "Read"
-            
+
             # TODO: reduce this, actual limit is around (cl+1)/nphases
-            delayed_enter(m, "RTW", "WRITE", settings.phy.read_latency-1)
+            delayed_enter(m, "RTW", "Write", settings.phy.read_latency-1)
 
-        if settings.with_bandwidth:
-            data_width = settings.phy.dfi_databits*settings.phy.nphases
-            self.submodules.bandwidth = Bandwidth(self.choose_req.cmd, data_width)
+        return m