Port FakePHY to nMigen
authorJean THOMAS <git0@pub.jeanthomas.me>
Wed, 8 Jul 2020 16:42:39 +0000 (18:42 +0200)
committerJean THOMAS <git0@pub.jeanthomas.me>
Wed, 8 Jul 2020 16:42:39 +0000 (18:42 +0200)
gram/phy/fakephy.py

index 79354b2b56193ca7f25aefcedc14be0eb917d1e8..60716bb9de41779e7c03ddfedc5745eff0be29f3 100644 (file)
@@ -6,8 +6,8 @@
 # TODO:
 # - add multirank support.
 
-from nmigen.compat import *
-from nmigen.compat.fhdl.module import CompatModule
+from nmigen import *
+from nmigen.utils import log2_int
 
 from gram.common import burst_lengths
 from gram.phy.dfi import *
@@ -27,7 +27,7 @@ def Display(*args):
 
 # Bank Model ---------------------------------------------------------------------------------------
 
-class BankModel(CompatModule):
+class BankModel(Elaboratable):
     def __init__(self, data_width, nrows, ncols, burst_length, nphases, we_granularity, init):
         self.activate     = Signal()
         self.activate_row = Signal(range(nrows))
@@ -41,82 +41,105 @@ class BankModel(CompatModule):
         self.read         = Signal()
         self.read_col     = Signal(range(ncols))
         self.read_data    = Signal(data_width)
-
-        # # #
+        self.nphases = nphases
+        self.nrows = nrows
+        self.ncols = ncols
+        self.burst_length = burst_length
+        self.data_width = data_width
+        self.we_granularity = we_granularity
+        self.init = init
+
+    def elaborate(self, platform):
+        m = Module()
+
+        nrows = self.nrows
+        ncols = self.ncols
+        burst_length = self.burst_length
+        data_width = self.data_width
+        we_granularity = self.we_granularity
+        init = self.init
 
         active = Signal()
         row    = Signal(range(nrows))
 
-        self.sync += \
-            If(self.precharge,
-                active.eq(0),
-            ).Elif(self.activate,
+        with m.If(self.precharge):
+            m.d.sync += active.eq(0)
+        with m.Elif(self.activate):
+            m.d.sync += [
                 active.eq(1),
-                row.eq(self.activate_row)
-            )
+                row.eq(self.activate_row),
+            ]
 
-        bank_mem_len   = nrows*ncols//(burst_length*nphases)
-        mem            = Memory(width=data_width, depth=bank_mem_len, init=init)
-        write_port     = mem.get_port(write_capable=True, we_granularity=we_granularity)
-        read_port      = mem.get_port(async_read=True)
-        self.specials += mem, read_port, write_port
+        bank_mem_len   = nrows*ncols//(burst_length*self.nphases)
+        mem            = Memory(width=data_width, depth=bank_mem_len, init=init)
+        write_port     = mem.get_port(write_capable=True, we_granularity=we_granularity)
+        read_port      = mem.get_port(async_read=True)
+        # m.submodules += mem, read_port, write_port
 
         wraddr         = Signal(range(bank_mem_len))
         rdaddr         = Signal(range(bank_mem_len))
 
-        self.comb += [
-            wraddr.eq((row*ncols | self.write_col)[log2_int(burst_length*nphases):]),
-            rdaddr.eq((row*ncols | self.read_col)[log2_int(burst_length*nphases):]),
+        m.d.comb += [
+            wraddr.eq((row*ncols | self.write_col)[log2_int(burst_length*self.nphases):]),
+            rdaddr.eq((row*ncols | self.read_col)[log2_int(burst_length*self.nphases):]),
         ]
 
-        self.comb += [
-            If(active,
-                write_port.adr.eq(wraddr),
-                write_port.dat_w.eq(self.write_data),
-                If(we_granularity,
-                    write_port.we.eq(Replicate(self.write, data_width//8) & ~self.write_mask),
-                ).Else(
-                    write_port.we.eq(self.write),
-                ),
-                If(self.read,
-                    read_port.adr.eq(rdaddr),
-                    self.read_data.eq(read_port.dat_r)
-                )
-            )
-        ]
+        with m.If(active):
+            # m.d.comb += [
+            #     write_port.adr.eq(wraddr),
+            #     write_port.dat_w.eq(self.write_data),
+            # ]
+
+            # with m.If(we_granularity):
+            #     m.d.comb += write_port.we.eq(Replicate(self.write, data_width//8) & ~self.write_mask)
+            # with m.Else():
+            #     m.d.comb += write_port.we.eq(self.write)
+
+            with m.If(self.read):
+                # m.d.comb += [
+                #     read_port.adr.eq(rdaddr),
+                #     self.read_data.eq(read_port.dat_r),
+                # ]
+                m.d.comb += self.read_data.eq(0xDEADBEEF)
+
+        return m
 
 # DFI Phase Model ----------------------------------------------------------------------------------
 
-class DFIPhaseModel(CompatModule):
+class DFIPhaseModel(Elaboratable):
     def __init__(self, dfi, n):
-        phase = dfi.phases[n]
+        self.phase = dfi.phases[n]
 
-        self.bank         = phase.bank
-        self.address      = phase.address
+        self.bank         = self.phase.bank
+        self.address      = self.phase.address
 
-        self.wrdata       = phase.wrdata
-        self.wrdata_mask  = phase.wrdata_mask
+        self.wrdata       = self.phase.wrdata
+        self.wrdata_mask  = self.phase.wrdata_mask
 
-        self.rddata       = phase.rddata
-        self.rddata_valid = phase.rddata_valid
+        self.rddata       = self.phase.rddata
+        self.rddata_valid = self.phase.rddata_valid
 
         self.activate     = Signal()
         self.precharge    = Signal()
         self.write        = Signal()
         self.read         = Signal()
 
-        # # #
+    def elaborate(self, platform):
+        m = Module()
 
-        self.comb += [
-            If(~phase.cs_n & ~phase.ras_n & phase.cas_n,
-                self.activate.eq(phase.we_n),
-                self.precharge.eq(~phase.we_n)
-            ),
-            If(~phase.cs_n & phase.ras_n & ~phase.cas_n,
-                self.write.eq(~phase.we_n),
-                self.read.eq(phase.we_n)
-            )
-        ]
+        with m.If(~self.phase.cs_n & ~self.phase.ras_n & self.phase.cas_n):
+            m.d.comb += [
+                self.activate.eq(self.phase.we_n),
+                self.precharge.eq(~self.phase.we_n),
+            ]
+
+        with m.If(~self.phase.cs_n & self.phase.ras_n & ~self.phase.cas_n):
+            m.d.comb += [
+                self.write.eq(~self.phase.we_n),
+                self.read.eq(self.phase.we_n),
+            ]
+
+        return m
 
 # DFI Timings Checker ------------------------------------------------------------------------------
 
@@ -135,7 +158,7 @@ class TimingRule:
         self.delay = delay
 
 
-class DFITimingsChecker(CompatModule):
+class DFITimingsChecker(Elaboratable):
     CMDS = [
         # Name, cs & ras & cas & we value
         ("PRE",  "0010"), # Precharge
@@ -233,11 +256,26 @@ class DFITimingsChecker(CompatModule):
         self.prepare_timings(timings, refresh_mode, memtype)
         self.add_cmds()
         self.add_rules()
+        self.nphases = nphases
+        self.nbanks = nbanks
+        self.dfi = dfi
+        self.timings = timings
+        self.refresh_mode = refresh_mode
+        self.memtype = memtype
+        self.verbose = verbose
+
+    def elaborate(self, platform):
+        m = Module()
 
         cnt = Signal(64)
-        self.sync += cnt.eq(cnt + nphases)
+        m.d.sync += cnt.eq(cnt+self.nphases)
 
-        phases = dfi.phases
+        phases = self.dfi.phases
+        nbanks = self.nbanks
+        timings = self.timings
+        refresh_mode = self.refresh_mode
+        memtype = self.memtype
+        verbose = self.verbose
 
         last_cmd_ps = [[Signal.like(cnt) for _ in range(len(self.cmds))] for _ in range(nbanks)]
         last_cmd    = [Signal(4) for i in range(nbanks)]
@@ -245,139 +283,160 @@ class DFITimingsChecker(CompatModule):
         act_ps   = Array([Signal().like(cnt) for i in range(4)])
         act_curr = Signal(range(4))
 
-        ref_issued = Signal(nphases)
+        ref_issued = Signal(self.nphases)
 
         for np, phase in enumerate(phases):
             ps = Signal().like(cnt)
-            self.comb += ps.eq((cnt + np)*self.timings["tCK"])
+            m.d.comb += ps.eq((cnt + np)*int(self.timings["tCK"]))
             state = Signal(4)
-            self.comb += state.eq(Cat(phase.we_n, phase.cas_n, phase.ras_n, phase.cs_n))
+            m.d.comb += state.eq(Cat(phase.we_n, phase.cas_n, phase.ras_n, phase.cs_n))
             all_banks = Signal()
 
-            self.comb += all_banks.eq(
+            m.d.comb += all_banks.eq(
                 (self.cmds["REF"].enc == state) |
                 ((self.cmds["PRE"].enc == state) & phase.address[10])
             )
 
             # tREFI
-            self.comb += ref_issued[np].eq(self.cmds["REF"].enc == state)
+            m.d.comb += ref_issued[np].eq(self.cmds["REF"].enc == state)
 
             # Print debug information
-            if verbose:
-                for _, cmd in self.cmds.items():
-                    self.sync += [
-                        If(state == cmd.enc,
-                            If(all_banks,
-                                Display("[%016dps] P%0d " + cmd.name, ps, np)
-                            ).Else(
-                                Display("[%016dps] P%0d B%0d " + cmd.name, ps, np, phase.bank)
-                            )
-                        )
-                    ]
+            # TODO: find a way to bring back logging
+            # if verbose:
+            #     for _, cmd in self.cmds.items():
+            #         self.sync += [
+            #             If(state == cmd.enc,
+            #                 If(all_banks,
+            #                     Display("[%016dps] P%0d " + cmd.name, ps, np)
+            #                 ).Else(
+            #                     Display("[%016dps] P%0d B%0d " + cmd.name, ps, np, phase.bank)
+            #                 )
+            #             )
+            #         ]
 
             # Bank command monitoring
             for i in range(nbanks):
                 for _, curr in self.cmds.items():
                     cmd_recv = Signal()
-                    self.comb += cmd_recv.eq(((phase.bank == i) | all_banks) & (state == curr.enc))
+                    m.d.comb += cmd_recv.eq(((phase.bank == i) | all_banks) & (state == curr.enc))
 
                     # Checking rules from self.rules
-                    for _, prev in self.cmds.items():
-                        for rule in self.rules:
-                            if rule.prev == prev.name and rule.curr == curr.name:
-                                self.sync += [
-                                    If(cmd_recv & (last_cmd[i] == prev.enc) &
-                                       (ps < (last_cmd_ps[i][prev.idx] + rule.delay)),
-                                        Display("[%016dps] {} violation on bank %0d".format(rule.name), ps, i)
-                                    )
-                                ]
+                    # TODO: find a way to bring back logging
+                    # for _, prev in self.cmds.items():
+                    #     for rule in self.rules:
+                    #         if rule.prev == prev.name and rule.curr == curr.name:
+                    #             self.sync += [
+                    #                 If(cmd_recv & (last_cmd[i] == prev.enc) &
+                    #                    (ps < (last_cmd_ps[i][prev.idx] + rule.delay)),
+                    #                     Display("[%016dps] {} violation on bank %0d".format(rule.name), ps, i)
+                    #                 )
+                    #             ]
 
                     # Save command timestamp in an array
-                    self.sync += If(cmd_recv, last_cmd_ps[i][curr.idx].eq(ps), last_cmd[i].eq(state))
+                    with m.If(cmd_recv):
+                        m.d.comb += [
+                            last_cmd_ps[i][curr.idx].eq(ps),
+                            last_cmd[i].eq(state),
+                        ]
 
                     # tRRD & tFAW
                     if curr.name == "ACT":
                         act_next = Signal().like(act_curr)
-                        self.comb += act_next.eq(act_curr+1)
+                        m.d.comb += act_next.eq(act_curr+1)
 
                         # act_curr points to newest ACT timestamp
-                        self.sync += [
-                            If(cmd_recv & (ps < (act_ps[act_curr] + self.timings["tRRD"])),
-                                Display("[%016dps] tRRD violation on bank %0d", ps, i)
-                            )
-                        ]
+                        # TODO: find a way to bring back logging
+                        # self.sync += [
+                        #     If(cmd_recv & (ps < (act_ps[act_curr] + self.timings["tRRD"])),
+                        #         Display("[%016dps] tRRD violation on bank %0d", ps, i)
+                        #     )
+                        # ]
 
                         # act_next points to the oldest ACT timestamp
-                        self.sync += [
-                            If(cmd_recv & (ps < (act_ps[act_next] + self.timings["tFAW"])),
-                                Display("[%016dps] tFAW violation on bank %0d", ps, i)
-                            )
-                        ]
+                        # TODO: find a way to bring back logging
+                        # self.sync += [
+                        #     If(cmd_recv & (ps < (act_ps[act_next] + self.timings["tFAW"])),
+                        #         Display("[%016dps] tFAW violation on bank %0d", ps, i)
+                        #     )
+                        # ]
 
                         # Save ACT timestamp in a circular buffer
-                        self.sync += If(cmd_recv, act_ps[act_next].eq(ps), act_curr.eq(act_next))
+                        with m.If(cmd_recv):
+                            m.d.sync += [
+                                act_ps[act_next].eq(ps),
+                                act_curr.eq(act_next),
+                            ]
 
         # tREFI
         ref_ps      = Signal().like(cnt)
         ref_ps_mod  = Signal().like(cnt)
-        #ref_ps_diff = Signal(range(-2**63, 2**63))
         ref_ps_diff = Signal(signed(64))
         curr_diff   = Signal().like(ref_ps_diff)
 
-        self.comb += curr_diff.eq(ps - (ref_ps + self.timings["tREFI"]))
+        m.d.comb += curr_diff.eq(ps - (ref_ps + int(self.timings["tREFI"])))
 
         # Work in 64ms periods
-        self.sync += [
-            If(ref_ps_mod < int(64e9),
-                ref_ps_mod.eq(ref_ps_mod + nphases * self.timings["tCK"])
-            ).Else(
-                ref_ps_mod.eq(0)
-            )
-        ]
+        with m.If(ref_ps_mod < int(64e9)):
+            m.d.sync += ref_ps_mod.eq(ref_ps_mod + int(self.nphases * self.timings["tCK"]))
+        with m.Else():
+            m.d.sync += ref_ps_mod.eq(0)
 
         # Update timestamp and difference
-        self.sync += If(ref_issued != 0, ref_ps.eq(ps), ref_ps_diff.eq(ref_ps_diff - curr_diff))
+        with m.If(ref_issued != 0):
+            m.d.sync += [
+                ref_ps.eq(ps),
+                ref_ps_diff.eq(ref_ps_diff - curr_diff),
+            ]
 
-        self.sync += [
-            If((ref_ps_mod == 0) & (ref_ps_diff > 0),
-                Display("[%016dps] tREFI violation (64ms period): %0d", ps, ref_ps_diff)
-            )
-        ]
+        # TODO: find a way to bring back logging
+        # self.sync += [
+        #     If((ref_ps_mod == 0) & (ref_ps_diff > 0),
+        #         Display("[%016dps] tREFI violation (64ms period): %0d", ps, ref_ps_diff)
+        #     )
+        # ]
 
         # Report any refresh periods longer than tREFI
-        if verbose:
-            ref_done = Signal()
-            self.sync += [
-                If(ref_issued != 0,
-                    ref_done.eq(1),
-                    If(~ref_done,
-                        Display("[%016dps] Late refresh", ps)
-                    )
-                )
-            ]
-
-            self.sync += [
-                If((curr_diff > 0) & ref_done & (ref_issued == 0),
-                    Display("[%016dps] tREFI violation", ps),
-                    ref_done.eq(0)
-                )
-            ]
+        # TODO: find a way to bring back logging
+        # if verbose:
+        #     ref_done = Signal()
+        #     self.sync += [
+        #         If(ref_issued != 0,
+        #             ref_done.eq(1),
+        #             If(~ref_done,
+        #                 Display("[%016dps] Late refresh", ps)
+        #             )
+        #         )
+        #     ]
+
+        #     self.sync += [
+        #         If((curr_diff > 0) & ref_done & (ref_issued == 0),
+        #             Display("[%016dps] tREFI violation", ps),
+        #             ref_done.eq(0)
+        #         )
+        #     ]
 
         # There is a maximum delay between refreshes on >=DDR
         ref_limit = {"1x": 9, "2x": 17, "4x": 36}
         if memtype != "SDR":
             refresh_mode = "1x" if refresh_mode is None else refresh_mode
             ref_done = Signal()
-            self.sync += If(ref_issued != 0, ref_done.eq(1))
-            self.sync += [
-                If((ref_issued == 0) & ref_done &
-                   (ref_ps > (ps + ref_limit[refresh_mode] * self.timings['tREFI'])),
-                    Display("[%016dps] tREFI violation (too many postponed refreshes)", ps),
-                    ref_done.eq(0)
-                )
-            ]
-
-class FakePHY(CompatModule):
+            with m.If(ref_issued != 0):
+                m.d.sync += ref_done.eq(1)
+
+            with m.If((ref_issued == 0) & ref_done &
+                   (ref_ps > (ps + int(ref_limit[refresh_mode] * self.timings['tREFI'])))):
+                m.d.sync += ref_done.eq(0)
+            # self.sync += [
+            #     If((ref_issued == 0) & ref_done &
+            #        (ref_ps > (ps + ref_limit[refresh_mode] * self.timings['tREFI'])),
+            #         Display("[%016dps] tREFI violation (too many postponed refreshes)", ps),
+            #         ref_done.eq(0)
+            #     )
+            # ]
+
+        return m
+
+class FakePHY(Elaboratable):
     def __prepare_bank_init_data(self, init, nbanks, nrows, ncols, data_width, address_mapping):
         mem_size          = (self.settings.databits//8)*(nrows*ncols*nbanks)
         bank_size         = mem_size // nbanks
@@ -438,7 +497,7 @@ class FakePHY(CompatModule):
         verbosity              = SDRAM_VERBOSE_OFF):
 
         # Parameters -------------------------------------------------------------------------------
-        burst_length = {
+        self.burst_length = {
             "SDR":   1,
             "DDR":   2,
             "LPDDR": 2,
@@ -447,38 +506,45 @@ class FakePHY(CompatModule):
             "DDR4":  2,
             }[settings.memtype]
 
-        addressbits   = module.geom_settings.addressbits
-        bankbits      = module.geom_settings.bankbits
-        rowbits       = module.geom_settings.rowbits
-        colbits       = module.geom_settings.colbits
+        self.addressbits = module.geom_settings.addressbits
+        self.bankbits = module.geom_settings.bankbits
+        self.rowbits = module.geom_settings.rowbits
+        self.colbits = module.geom_settings.colbits
 
         self.settings = settings
-        self.module   = module
+        self.module = module
+
+        self.verbosity = verbosity
+        self.clk_freq = clk_freq
+        self.we_granularity = we_granularity
+
+        self.init = init
 
         # DFI Interface ----------------------------------------------------------------------------
         self.dfi = Interface(
-            addressbits = addressbits,
-            bankbits    = bankbits,
+            addressbits = self.addressbits,
+            bankbits    = self.bankbits,
             nranks      = self.settings.nranks,
             databits    = self.settings.dfi_databits,
             nphases     = self.settings.nphases
         )
 
-        # # #
+    def elaborate(self, platform):
+        m = Module()
 
         nphases    = self.settings.nphases
-        nbanks     = 2**bankbits
-        nrows      = 2**rowbits
-        ncols      = 2**colbits
+        nbanks     = 2**self.bankbits
+        nrows      = 2**self.rowbits
+        ncols      = 2**self.colbits
         data_width = self.settings.dfi_databits*self.settings.nphases
 
         # DFI phases -------------------------------------------------------------------------------
         phases = [DFIPhaseModel(self.dfi, n) for n in range(self.settings.nphases)]
-        self.submodules += phases
+        m.submodules += phases
 
         # DFI timing checker -----------------------------------------------------------------------
-        if verbosity > SDRAM_VERBOSE_OFF:
-            timings = {"tCK": (1e9 / clk_freq) / nphases}
+        if self.verbosity > SDRAM_VERBOSE_OFF:
+            timings = {"tCK": (1e9 / self.clk_freq) / nphases}
 
             for name in _speedgrade_timings + _technology_timings:
                 timings[name] = self.module.get(name)
@@ -489,16 +555,16 @@ class FakePHY(CompatModule):
                 nphases      = nphases,
                 timings      = timings,
                 refresh_mode = self.module.timing_settings.fine_refresh_mode,
-                memtype      = settings.memtype,
-                verbose      = verbosity > SDRAM_VERBOSE_DBG)
-            self.submodules += timing_checker
+                memtype      = self.settings.memtype,
+                verbose      = self.verbosity > SDRAM_VERBOSE_DBG)
+            m.submodules += timing_checker
 
         # Bank init data ---------------------------------------------------------------------------
         bank_init  = [None for i in range(nbanks)]
 
-        if init:
+        if self.init:
             bank_init = self.__prepare_bank_init_data(
-                init            = init,
+                init            = self.init,
                 nbanks          = nbanks,
                 nrows           = nrows,
                 ncols           = ncols,
@@ -511,48 +577,46 @@ class FakePHY(CompatModule):
             data_width     = data_width,
             nrows          = nrows,
             ncols          = ncols,
-            burst_length   = burst_length,
+            burst_length   = self.burst_length,
             nphases        = nphases,
-            we_granularity = we_granularity,
+            we_granularity = self.we_granularity,
             init           = bank_init[i]) for i in range(nbanks)]
-        self.submodules += banks
+        m.submodules += banks
 
         # Connect DFI phases to Banks (CMDs, Write datapath) ---------------------------------------
         for nb, bank in enumerate(banks):
             # Bank activate
             activates = Signal(len(phases))
-            cases     = {}
-            for np, phase in enumerate(phases):
-                self.comb += activates[np].eq(phase.activate)
-                cases[2**np] = [
-                    bank.activate.eq(phase.bank == nb),
-                    bank.activate_row.eq(phase.address)
-                ]
-            self.comb += Case(activates, cases)
+            with m.Switch(activates):
+                for np, phase in enumerate(phases):
+                    m.d.comb += activates[np].eq(phase.activate)
+                    with m.Case(2**np):
+                        m.d.comb +=  [
+                            bank.activate.eq(phase.bank == nb),
+                            bank.activate_row.eq(phase.address)
+                        ]
 
             # Bank precharge
             precharges = Signal(len(phases))
-            cases      = {}
-            for np, phase in enumerate(phases):
-                self.comb += precharges[np].eq(phase.precharge)
-                cases[2**np] = [
-                    bank.precharge.eq((phase.bank == nb) | phase.address[10])
-                ]
-            self.comb += Case(precharges, cases)
+            with m.Switch(precharges):
+                for np, phase in enumerate(phases):
+                    m.d.comb += precharges[np].eq(phase.precharge)
+                    with m.Case(2**np):
+                        m.d.comb += bank.precharge.eq((phase.bank == nb) | phase.address[10])
 
             # Bank writes
             bank_write = Signal()
             bank_write_col = Signal(range(ncols))
             writes = Signal(len(phases))
-            cases  = {}
-            for np, phase in enumerate(phases):
-                self.comb += writes[np].eq(phase.write)
-                cases[2**np] = [
-                    bank_write.eq(phase.bank == nb),
-                    bank_write_col.eq(phase.address)
-                ]
-            self.comb += Case(writes, cases)
-            self.comb += [
+            with m.Switch(writes):
+                for np, phase in enumerate(phases):
+                    m.d.comb += writes[np].eq(phase.write)
+                    with m.Case(2**np):
+                        m.d.comb += [
+                            bank_write.eq(phase.bank == nb),
+                            bank_write_col.eq(phase.address)
+                        ]
+            m.d.comb += [
                 bank.write_data.eq(Cat(*[phase.wrdata for phase in phases])),
                 bank.write_mask.eq(Cat(*[phase.wrdata_mask for phase in phases]))
             ]
@@ -561,33 +625,33 @@ class FakePHY(CompatModule):
             for i in range(self.settings.write_latency):
                 new_bank_write     = Signal()
                 new_bank_write_col = Signal(range(ncols))
-                self.sync += [
+                m.d.sync += [
                     new_bank_write.eq(bank_write),
                     new_bank_write_col.eq(bank_write_col)
                 ]
                 bank_write = new_bank_write
                 bank_write_col = new_bank_write_col
 
-            self.comb += [
+            m.d.comb += [
                 bank.write.eq(bank_write),
                 bank.write_col.eq(bank_write_col)
             ]
 
             # Bank reads
             reads = Signal(len(phases))
-            cases = {}
-            for np, phase in enumerate(phases):
-                self.comb += reads[np].eq(phase.read)
-                cases[2**np] = [
-                    bank.read.eq(phase.bank == nb),
-                    bank.read_col.eq(phase.address)
-            ]
-            self.comb += Case(reads, cases)
+            with m.Switch(reads):
+                for np, phase in enumerate(phases):
+                    m.d.comb += reads[np].eq(phase.read)
+                    with m.Case(2**np):
+                        m.d.comb += [
+                            bank.read.eq(phase.bank == nb),
+                            bank.read_col.eq(phase.address),
+                        ]
 
         # Connect Banks to DFI phases (CMDs, Read datapath) ----------------------------------------
         banks_read      = Signal()
         banks_read_data = Signal(data_width)
-        self.comb += [
+        m.d.comb += [
             banks_read.eq(reduce(or_, [bank.read for bank in banks])),
             banks_read_data.eq(reduce(or_, [bank.read_data for bank in banks]))
         ]
@@ -596,14 +660,16 @@ class FakePHY(CompatModule):
         for i in range(self.settings.read_latency):
             new_banks_read      = Signal()
             new_banks_read_data = Signal(data_width)
-            self.sync += [
+            m.d.sync += [
                 new_banks_read.eq(banks_read),
                 new_banks_read_data.eq(banks_read_data)
             ]
             banks_read      = new_banks_read
             banks_read_data = new_banks_read_data
 
-        self.comb += [
+        m.d.comb += [
             Cat(*[phase.rddata_valid for phase in phases]).eq(banks_read),
             Cat(*[phase.rddata for phase in phases]).eq(banks_read_data)
         ]
+
+        return m