Rework RAM port for nMigen compliance
[gram.git] / gram / core / crossbar.py
index e71650f369cdaee4ae3a94a19b7d9f926a0e6c68..c81abc77e1a597ccf7e4ec9d09e8c3912151d9e6 100644 (file)
@@ -19,7 +19,7 @@ import gram.stream as stream
 
 # LiteDRAMCrossbar ---------------------------------------------------------------------------------
 
-class gramCrossbar(Module):
+class gramCrossbar(Elaboratable):
     """Multiplexes LiteDRAMController (slave) between ports (masters)
 
     To get a port to LiteDRAM, use the `get_port` method. It handles data width
@@ -72,17 +72,16 @@ class gramCrossbar(Module):
         self.rank_bits = log2_int(self.nranks, False)
 
         self.masters = []
+        self._pending_submodules = []
 
     def get_port(self, mode="both", data_width=None, clock_domain="sys", reverse=False):
-        if self.finalized:
-            raise FinalizeError
-
         if data_width is None:
             # use internal data_width when no width adaptation is requested
             data_width = self.controller.data_width
+            print("data_width=", data_width)
 
         # Crossbar port ----------------------------------------------------------------------------
-        port = LiteDRAMNativePort(
+        port = gramNativePort(
             mode          = mode,
             address_width = self.rca_bits + self.bank_bits - self.rank_bits,
             data_width    = self.controller.data_width,
@@ -92,13 +91,13 @@ class gramCrossbar(Module):
 
         # Clock domain crossing --------------------------------------------------------------------
         if clock_domain != "sys":
-            new_port = LiteDRAMNativePort(
+            new_port = gramNativePort(
                 mode          = mode,
                 address_width = port.address_width,
                 data_width    = port.data_width,
                 clock_domain  = clock_domain,
                 id            = port.id)
-            self.submodules += LiteDRAMNativePortCDC(new_port, port)
+            self._pending_submodules.append(gramNativePortCDC(new_port, port))
             port = new_port
 
         # Data width convertion --------------------------------------------------------------------
@@ -107,34 +106,38 @@ class gramCrossbar(Module):
                 addr_shift = -log2_int(data_width//self.controller.data_width)
             else:
                 addr_shift = log2_int(self.controller.data_width//data_width)
-            new_port = LiteDRAMNativePort(
+            new_port = gramNativePort(
                 mode          = mode,
                 address_width = port.address_width + addr_shift,
                 data_width    = data_width,
                 clock_domain  = clock_domain,
                 id            = port.id)
-            self.submodules += ClockDomainsRenamer(clock_domain)(
-                LiteDRAMNativePortConverter(new_port, port, reverse))
+            self._pending_submodules.append(ClockDomainsRenamer(clock_domain)(
+                gramNativePortConverter(new_port, port, reverse)))
             port = new_port
 
         return port
 
-    def do_finalize(self):
+    def elaborate(self, platform):
+        m = Module()
+
+        m.submodules += self._pending_submodules
+
         controller = self.controller
         nmasters   = len(self.masters)
 
         # Address mapping --------------------------------------------------------------------------
         cba_shifts = {"ROW_BANK_COL": controller.settings.geom.colbits - controller.address_align}
         cba_shift = cba_shifts[controller.settings.address_mapping]
-        m_ba      = [m.get_bank_address(self.bank_bits, cba_shift)for m in self.masters]
-        m_rca     = [m.get_row_column_address(self.bank_bits, self.rca_bits, cba_shift) for m in self.masters]
+        m_ba      = [master.get_bank_address(self.bank_bits, cba_shift) for master in self.masters]
+        m_rca     = [master.get_row_column_address(self.bank_bits, self.rca_bits, cba_shift) for master in self.masters]
 
         master_readys       = [0]*nmasters
         master_wdata_readys = [0]*nmasters
         master_rdata_valids = [0]*nmasters
 
-        arbiters = [roundrobin.RoundRobin(nmasters, roundrobin.SP_CE) for n in range(self.nbanks)]
-        self.submodules += arbiters
+        arbiters = [RoundRobin(nmasters) for n in range(self.nbanks)]
+        m.submodules += arbiters
 
         for nb, arbiter in enumerate(arbiters):
             bank = getattr(controller, "bank"+str(nb))
@@ -152,13 +155,13 @@ class gramCrossbar(Module):
             # Arbitrate ----------------------------------------------------------------------------
             bank_selected  = [(ba == nb) & ~locked for ba, locked in zip(m_ba, master_locked)]
             bank_requested = [bs & master.cmd.valid for bs, master in zip(bank_selected, self.masters)]
-            self.comb += [
+            m.d.comb += [
                 arbiter.request.eq(Cat(*bank_requested)),
-                arbiter.ce.eq(~bank.valid & ~bank.lock)
+                arbiter.stb.eq(~bank.valid & ~bank.lock)
             ]
 
             # Route requests -----------------------------------------------------------------------
-            self.comb += [
+            m.d.comb += [
                 bank.addr.eq(Array(m_rca)[arbiter.grant]),
                 bank.we.eq(Array(self.masters)[arbiter.grant].cmd.we),
                 bank.valid.eq(Array(bank_requested)[arbiter.grant])
@@ -174,37 +177,40 @@ class gramCrossbar(Module):
         for nm, master_wdata_ready in enumerate(master_wdata_readys):
             for i in range(self.write_latency):
                 new_master_wdata_ready = Signal()
-                self.sync += new_master_wdata_ready.eq(master_wdata_ready)
+                m.d.sync += new_master_wdata_ready.eq(master_wdata_ready)
                 master_wdata_ready = new_master_wdata_ready
             master_wdata_readys[nm] = master_wdata_ready
 
         for nm, master_rdata_valid in enumerate(master_rdata_valids):
             for i in range(self.read_latency):
                 new_master_rdata_valid = Signal()
-                self.sync += new_master_rdata_valid.eq(master_rdata_valid)
+                m.d.sync += new_master_rdata_valid.eq(master_rdata_valid)
                 master_rdata_valid = new_master_rdata_valid
             master_rdata_valids[nm] = master_rdata_valid
 
         for master, master_ready in zip(self.masters, master_readys):
-            self.comb += master.cmd.ready.eq(master_ready)
+            m.d.comb += master.cmd.ready.eq(master_ready)
         for master, master_wdata_ready in zip(self.masters, master_wdata_readys):
-            self.comb += master.wdata.ready.eq(master_wdata_ready)
+            m.d.comb += master.wdata.ready.eq(master_wdata_ready)
         for master, master_rdata_valid in zip(self.masters, master_rdata_valids):
-            self.comb += master.rdata.valid.eq(master_rdata_valid)
+            m.d.comb += master.rdata.valid.eq(master_rdata_valid)
 
         # Route data writes ------------------------------------------------------------------------
-        wdata_cases = {}
-        for nm, master in enumerate(self.masters):
-            wdata_cases[2**nm] = [
-                controller.wdata.eq(master.wdata.data),
-                controller.wdata_we.eq(master.wdata.we)
-            ]
-        wdata_cases["default"] = [
-            controller.wdata.eq(0),
-            controller.wdata_we.eq(0)
-        ]
-        self.comb += Case(Cat(*master_wdata_readys), wdata_cases)
+        with m.Switch(Cat(*master_wdata_readys)):
+            with m.Case():
+                m.d.comb += [
+                    controller.wdata.eq(0),
+                    controller.wdata_we.eq(0),
+                ]
+            for nm, master in enumerate(self.masters):
+                with m.Case(2**nm):
+                    m.d.comb += [
+                        controller.wdata.eq(master.wdata.data),
+                        controller.wdata_we.eq(master.wdata.we),
+                    ]
 
         # Route data reads -------------------------------------------------------------------------
         for master in self.masters:
-            self.comb += master.rdata.data.eq(controller.rdata)
+            m.d.comb += master.rdata.data.eq(controller.rdata)
+
+        return m