Import fake PHY from LiteDRAM (non functionnal ATM)
authorJean THOMAS <git0@pub.jeanthomas.me>
Wed, 8 Jul 2020 15:30:35 +0000 (17:30 +0200)
committerJean THOMAS <git0@pub.jeanthomas.me>
Wed, 8 Jul 2020 15:30:35 +0000 (17:30 +0200)
gram/phy/fakephy.py [new file with mode: 0644]

diff --git a/gram/phy/fakephy.py b/gram/phy/fakephy.py
new file mode 100644 (file)
index 0000000..79354b2
--- /dev/null
@@ -0,0 +1,609 @@
+# This file is Copyright (c) 2015-2020 Florent Kermarrec <florent@enjoy-digital.fr>
+# This file is Copyright (c) 2020 Antmicro <www.antmicro.com>
+# License: BSD
+
+# SDRAM simulation PHY at DFI level tested with SDR/DDR/DDR2/LPDDR/DDR3
+# TODO:
+# - add multirank support.
+
+from nmigen.compat import *
+from nmigen.compat.fhdl.module import CompatModule
+
+from gram.common import burst_lengths
+from gram.phy.dfi import *
+from gram.modules import _speedgrade_timings, _technology_timings
+
+from functools import reduce
+from operator import or_
+
+import struct
+
+SDRAM_VERBOSE_OFF = 0
+SDRAM_VERBOSE_STD = 1
+SDRAM_VERBOSE_DBG = 2
+
+def Display(*args):
+    return Signal().eq(0)
+
+# Bank Model ---------------------------------------------------------------------------------------
+
+class BankModel(CompatModule):
+    def __init__(self, data_width, nrows, ncols, burst_length, nphases, we_granularity, init):
+        self.activate     = Signal()
+        self.activate_row = Signal(range(nrows))
+        self.precharge    = Signal()
+
+        self.write        = Signal()
+        self.write_col    = Signal(range(ncols))
+        self.write_data   = Signal(data_width)
+        self.write_mask   = Signal(data_width//8)
+
+        self.read         = Signal()
+        self.read_col     = Signal(range(ncols))
+        self.read_data    = Signal(data_width)
+
+        # # #
+
+        active = Signal()
+        row    = Signal(range(nrows))
+
+        self.sync += \
+            If(self.precharge,
+                active.eq(0),
+            ).Elif(self.activate,
+                active.eq(1),
+                row.eq(self.activate_row)
+            )
+
+        bank_mem_len   = nrows*ncols//(burst_length*nphases)
+        mem            = Memory(width=data_width, depth=bank_mem_len, init=init)
+        write_port     = mem.get_port(write_capable=True, we_granularity=we_granularity)
+        read_port      = mem.get_port(async_read=True)
+        self.specials += mem, read_port, write_port
+
+        wraddr         = Signal(range(bank_mem_len))
+        rdaddr         = Signal(range(bank_mem_len))
+
+        self.comb += [
+            wraddr.eq((row*ncols | self.write_col)[log2_int(burst_length*nphases):]),
+            rdaddr.eq((row*ncols | self.read_col)[log2_int(burst_length*nphases):]),
+        ]
+
+        self.comb += [
+            If(active,
+                write_port.adr.eq(wraddr),
+                write_port.dat_w.eq(self.write_data),
+                If(we_granularity,
+                    write_port.we.eq(Replicate(self.write, data_width//8) & ~self.write_mask),
+                ).Else(
+                    write_port.we.eq(self.write),
+                ),
+                If(self.read,
+                    read_port.adr.eq(rdaddr),
+                    self.read_data.eq(read_port.dat_r)
+                )
+            )
+        ]
+
+# DFI Phase Model ----------------------------------------------------------------------------------
+
+class DFIPhaseModel(CompatModule):
+    def __init__(self, dfi, n):
+        phase = dfi.phases[n]
+
+        self.bank         = phase.bank
+        self.address      = phase.address
+
+        self.wrdata       = phase.wrdata
+        self.wrdata_mask  = phase.wrdata_mask
+
+        self.rddata       = phase.rddata
+        self.rddata_valid = phase.rddata_valid
+
+        self.activate     = Signal()
+        self.precharge    = Signal()
+        self.write        = Signal()
+        self.read         = Signal()
+
+        # # #
+
+        self.comb += [
+            If(~phase.cs_n & ~phase.ras_n & phase.cas_n,
+                self.activate.eq(phase.we_n),
+                self.precharge.eq(~phase.we_n)
+            ),
+            If(~phase.cs_n & phase.ras_n & ~phase.cas_n,
+                self.write.eq(~phase.we_n),
+                self.read.eq(phase.we_n)
+            )
+        ]
+
+# DFI Timings Checker ------------------------------------------------------------------------------
+
+class SDRAMCMD:
+    def __init__(self, name: str, enc: int, idx: int):
+        self.name = name
+        self.enc  = enc
+        self.idx  = idx
+
+
+class TimingRule:
+    def __init__(self, prev: str, curr: str, delay: int):
+        self.name  = prev + "->" + curr
+        self.prev  = prev
+        self.curr  = curr
+        self.delay = delay
+
+
+class DFITimingsChecker(CompatModule):
+    CMDS = [
+        # Name, cs & ras & cas & we value
+        ("PRE",  "0010"), # Precharge
+        ("REF",  "0001"), # Self refresh
+        ("ACT",  "0011"), # Activate
+        ("RD",   "0101"), # Read
+        ("WR",   "0100"), # Write
+        ("ZQCS", "0110"), # ZQCS
+    ]
+
+    RULES = [
+        # tRP
+        ("PRE",  "ACT", "tRP"),
+        ("PRE",  "REF", "tRP"),
+        # tRCD
+        ("ACT",  "WR",  "tRCD"),
+        ("ACT",  "RD",  "tRCD"),
+        # tRAS
+        ("ACT",  "PRE", "tRAS"),
+        # tRFC
+        ("REF",  "PRE", "tRFC"),
+        ("REF",  "ACT", "tRFC"),
+        # tCCD
+        ("WR",   "RD",  "tCCD"),
+        ("WR",   "WR",  "tCCD"),
+        ("RD",   "RD",  "tCCD"),
+        ("RD",   "WR",  "tCCD"),
+        # tRC
+        ("ACT",  "ACT", "tRC"),
+        # tWR
+        ("WR",   "PRE", "tWR"),
+        # tWTR
+        ("WR",   "RD",  "tWTR"),
+        # tZQCS
+        ("ZQCS", "ACT", "tZQCS"),
+    ]
+
+    def add_cmds(self):
+        self.cmds = {}
+        for idx, (name, pattern) in enumerate(self.CMDS):
+            self.cmds[name] = SDRAMCMD(name, int(pattern, 2), idx)
+
+    def add_rule(self, prev, curr, delay):
+        if not isinstance(delay, int):
+            delay = self.timings[delay]
+        self.rules.append(TimingRule(prev, curr, delay))
+
+    def add_rules(self):
+        self.rules = []
+        for rule in self.RULES:
+            self.add_rule(*rule)
+
+    # Convert ns to ps
+    def ns_to_ps(self, val):
+        return int(val * 1e3)
+
+    def ck_ns_to_ps(self, val, tck):
+        c, t = val
+        c = 0 if c is None else c * tck
+        t = 0 if t is None else t
+        return self.ns_to_ps(max(c, t))
+
+    def prepare_timings(self, timings, refresh_mode, memtype):
+        CK_NS = ["tRFC", "tWTR", "tFAW", "tCCD", "tRRD", "tZQCS"]
+        REF   = ["tREFI", "tRFC"]
+        self.timings = timings
+        new_timings  = {}
+
+        tck = self.timings["tCK"]
+
+        for key, val in self.timings.items():
+            if refresh_mode is not None and key in REF:
+                val = val[refresh_mode]
+
+            if val is None:
+                val = 0
+            elif key in CK_NS:
+                val = self.ck_ns_to_ps(val, tck)
+            else:
+                val = self.ns_to_ps(val)
+
+            new_timings[key] = val
+
+        new_timings["tRC"] = new_timings["tRAS"] + new_timings["tRP"]
+
+        # Adjust timings relative to write burst - tWR & tWTR
+        wrburst = burst_lengths[memtype] if memtype == "SDR" else burst_lengths[memtype] // 2
+        wrburst = (new_timings["tCK"] * (wrburst - 1))
+        new_timings["tWR"]  = new_timings["tWR"]  + wrburst
+        new_timings["tWTR"] = new_timings["tWTR"] + wrburst
+
+        self.timings = new_timings
+
+    def __init__(self, dfi, nbanks, nphases, timings, refresh_mode, memtype, verbose=False):
+        self.prepare_timings(timings, refresh_mode, memtype)
+        self.add_cmds()
+        self.add_rules()
+
+        cnt = Signal(64)
+        self.sync += cnt.eq(cnt + nphases)
+
+        phases = dfi.phases
+
+        last_cmd_ps = [[Signal.like(cnt) for _ in range(len(self.cmds))] for _ in range(nbanks)]
+        last_cmd    = [Signal(4) for i in range(nbanks)]
+
+        act_ps   = Array([Signal().like(cnt) for i in range(4)])
+        act_curr = Signal(range(4))
+
+        ref_issued = Signal(nphases)
+
+        for np, phase in enumerate(phases):
+            ps = Signal().like(cnt)
+            self.comb += ps.eq((cnt + np)*self.timings["tCK"])
+            state = Signal(4)
+            self.comb += state.eq(Cat(phase.we_n, phase.cas_n, phase.ras_n, phase.cs_n))
+            all_banks = Signal()
+
+            self.comb += all_banks.eq(
+                (self.cmds["REF"].enc == state) |
+                ((self.cmds["PRE"].enc == state) & phase.address[10])
+            )
+
+            # tREFI
+            self.comb += ref_issued[np].eq(self.cmds["REF"].enc == state)
+
+            # Print debug information
+            if verbose:
+                for _, cmd in self.cmds.items():
+                    self.sync += [
+                        If(state == cmd.enc,
+                            If(all_banks,
+                                Display("[%016dps] P%0d " + cmd.name, ps, np)
+                            ).Else(
+                                Display("[%016dps] P%0d B%0d " + cmd.name, ps, np, phase.bank)
+                            )
+                        )
+                    ]
+
+            # Bank command monitoring
+            for i in range(nbanks):
+                for _, curr in self.cmds.items():
+                    cmd_recv = Signal()
+                    self.comb += cmd_recv.eq(((phase.bank == i) | all_banks) & (state == curr.enc))
+
+                    # Checking rules from self.rules
+                    for _, prev in self.cmds.items():
+                        for rule in self.rules:
+                            if rule.prev == prev.name and rule.curr == curr.name:
+                                self.sync += [
+                                    If(cmd_recv & (last_cmd[i] == prev.enc) &
+                                       (ps < (last_cmd_ps[i][prev.idx] + rule.delay)),
+                                        Display("[%016dps] {} violation on bank %0d".format(rule.name), ps, i)
+                                    )
+                                ]
+
+                    # Save command timestamp in an array
+                    self.sync += If(cmd_recv, last_cmd_ps[i][curr.idx].eq(ps), last_cmd[i].eq(state))
+
+                    # tRRD & tFAW
+                    if curr.name == "ACT":
+                        act_next = Signal().like(act_curr)
+                        self.comb += act_next.eq(act_curr+1)
+
+                        # act_curr points to newest ACT timestamp
+                        self.sync += [
+                            If(cmd_recv & (ps < (act_ps[act_curr] + self.timings["tRRD"])),
+                                Display("[%016dps] tRRD violation on bank %0d", ps, i)
+                            )
+                        ]
+
+                        # act_next points to the oldest ACT timestamp
+                        self.sync += [
+                            If(cmd_recv & (ps < (act_ps[act_next] + self.timings["tFAW"])),
+                                Display("[%016dps] tFAW violation on bank %0d", ps, i)
+                            )
+                        ]
+
+                        # Save ACT timestamp in a circular buffer
+                        self.sync += If(cmd_recv, act_ps[act_next].eq(ps), act_curr.eq(act_next))
+
+        # tREFI
+        ref_ps      = Signal().like(cnt)
+        ref_ps_mod  = Signal().like(cnt)
+        #ref_ps_diff = Signal(range(-2**63, 2**63))
+        ref_ps_diff = Signal(signed(64))
+        curr_diff   = Signal().like(ref_ps_diff)
+
+        self.comb += curr_diff.eq(ps - (ref_ps + self.timings["tREFI"]))
+
+        # Work in 64ms periods
+        self.sync += [
+            If(ref_ps_mod < int(64e9),
+                ref_ps_mod.eq(ref_ps_mod + nphases * self.timings["tCK"])
+            ).Else(
+                ref_ps_mod.eq(0)
+            )
+        ]
+
+        # Update timestamp and difference
+        self.sync += If(ref_issued != 0, ref_ps.eq(ps), ref_ps_diff.eq(ref_ps_diff - curr_diff))
+
+        self.sync += [
+            If((ref_ps_mod == 0) & (ref_ps_diff > 0),
+                Display("[%016dps] tREFI violation (64ms period): %0d", ps, ref_ps_diff)
+            )
+        ]
+
+        # Report any refresh periods longer than tREFI
+        if verbose:
+            ref_done = Signal()
+            self.sync += [
+                If(ref_issued != 0,
+                    ref_done.eq(1),
+                    If(~ref_done,
+                        Display("[%016dps] Late refresh", ps)
+                    )
+                )
+            ]
+
+            self.sync += [
+                If((curr_diff > 0) & ref_done & (ref_issued == 0),
+                    Display("[%016dps] tREFI violation", ps),
+                    ref_done.eq(0)
+                )
+            ]
+
+        # There is a maximum delay between refreshes on >=DDR
+        ref_limit = {"1x": 9, "2x": 17, "4x": 36}
+        if memtype != "SDR":
+            refresh_mode = "1x" if refresh_mode is None else refresh_mode
+            ref_done = Signal()
+            self.sync += If(ref_issued != 0, ref_done.eq(1))
+            self.sync += [
+                If((ref_issued == 0) & ref_done &
+                   (ref_ps > (ps + ref_limit[refresh_mode] * self.timings['tREFI'])),
+                    Display("[%016dps] tREFI violation (too many postponed refreshes)", ps),
+                    ref_done.eq(0)
+                )
+            ]
+
+class FakePHY(CompatModule):
+    def __prepare_bank_init_data(self, init, nbanks, nrows, ncols, data_width, address_mapping):
+        mem_size          = (self.settings.databits//8)*(nrows*ncols*nbanks)
+        bank_size         = mem_size // nbanks
+        column_size       = bank_size // nrows
+        model_bank_size   = bank_size // (data_width//8)
+        model_column_size = model_bank_size // nrows
+        model_data_ratio  = data_width // 32
+        data_width_bytes  = data_width // 8
+        bank_init         = [[] for i in range(nbanks)]
+
+        # Pad init if too short
+        if len(init)%data_width_bytes != 0:
+            init.extend([0]*(data_width_bytes-len(init)%data_width_bytes))
+
+
+        # Convert init data width from 32-bit to data_width if needed
+        if model_data_ratio > 1:
+            new_init = [0]*(len(init)//model_data_ratio)
+            for i in range(0, len(init), model_data_ratio):
+                ints = init[i:i+model_data_ratio]
+                strs = "".join("{:08x}".format(x) for x in reversed(ints))
+                new_init[i//model_data_ratio] = int(strs, 16)
+            init = new_init
+        elif model_data_ratio == 0:
+            assert data_width_bytes in [1, 2]
+            model_data_ratio = 4 // data_width_bytes
+            struct_unpack_patterns = {1: "4B", 2: "2H"}
+            new_init = [0]*int(len(init)*model_data_ratio)
+            for i in range(len(init)):
+                new_init[model_data_ratio*i:model_data_ratio*(i+1)] = struct.unpack(
+                    struct_unpack_patterns[data_width_bytes],
+                    struct.pack("I", init[i])
+                )[0:model_data_ratio]
+            init = new_init
+
+        if address_mapping == "ROW_BANK_COL":
+            for row in range(nrows):
+                for bank in range(nbanks):
+                    start = (row*nbanks*model_column_size + bank*model_column_size)
+                    end   = min(start + model_column_size, len(init))
+                    if start > len(init):
+                        break
+                    bank_init[bank].extend(init[start:end])
+        elif address_mapping == "BANK_ROW_COL":
+            for bank in range(nbanks):
+                start = bank*model_bank_size
+                end   = min(start + model_bank_size, len(init))
+                if start > len(init):
+                    break
+                bank_init[bank] = init[start:end]
+
+        return bank_init
+
+    def __init__(self, module, settings, clk_freq=100e6,
+        we_granularity         = 8,
+        init                   = [],
+        address_mapping        = "ROW_BANK_COL",
+        verbosity              = SDRAM_VERBOSE_OFF):
+
+        # Parameters -------------------------------------------------------------------------------
+        burst_length = {
+            "SDR":   1,
+            "DDR":   2,
+            "LPDDR": 2,
+            "DDR2":  2,
+            "DDR3":  2,
+            "DDR4":  2,
+            }[settings.memtype]
+
+        addressbits   = module.geom_settings.addressbits
+        bankbits      = module.geom_settings.bankbits
+        rowbits       = module.geom_settings.rowbits
+        colbits       = module.geom_settings.colbits
+
+        self.settings = settings
+        self.module   = module
+
+        # DFI Interface ----------------------------------------------------------------------------
+        self.dfi = Interface(
+            addressbits = addressbits,
+            bankbits    = bankbits,
+            nranks      = self.settings.nranks,
+            databits    = self.settings.dfi_databits,
+            nphases     = self.settings.nphases
+        )
+
+        # # #
+
+        nphases    = self.settings.nphases
+        nbanks     = 2**bankbits
+        nrows      = 2**rowbits
+        ncols      = 2**colbits
+        data_width = self.settings.dfi_databits*self.settings.nphases
+
+        # DFI phases -------------------------------------------------------------------------------
+        phases = [DFIPhaseModel(self.dfi, n) for n in range(self.settings.nphases)]
+        self.submodules += phases
+
+        # DFI timing checker -----------------------------------------------------------------------
+        if verbosity > SDRAM_VERBOSE_OFF:
+            timings = {"tCK": (1e9 / clk_freq) / nphases}
+
+            for name in _speedgrade_timings + _technology_timings:
+                timings[name] = self.module.get(name)
+
+            timing_checker = DFITimingsChecker(
+                dfi          = self.dfi,
+                nbanks       = nbanks,
+                nphases      = nphases,
+                timings      = timings,
+                refresh_mode = self.module.timing_settings.fine_refresh_mode,
+                memtype      = settings.memtype,
+                verbose      = verbosity > SDRAM_VERBOSE_DBG)
+            self.submodules += timing_checker
+
+        # Bank init data ---------------------------------------------------------------------------
+        bank_init  = [None for i in range(nbanks)]
+
+        if init:
+            bank_init = self.__prepare_bank_init_data(
+                init            = init,
+                nbanks          = nbanks,
+                nrows           = nrows,
+                ncols           = ncols,
+                data_width      = data_width,
+                address_mapping = address_mapping
+            )
+
+        # Banks ------------------------------------------------------------------------------------
+        banks = [BankModel(
+            data_width     = data_width,
+            nrows          = nrows,
+            ncols          = ncols,
+            burst_length   = burst_length,
+            nphases        = nphases,
+            we_granularity = we_granularity,
+            init           = bank_init[i]) for i in range(nbanks)]
+        self.submodules += banks
+
+        # Connect DFI phases to Banks (CMDs, Write datapath) ---------------------------------------
+        for nb, bank in enumerate(banks):
+            # Bank activate
+            activates = Signal(len(phases))
+            cases     = {}
+            for np, phase in enumerate(phases):
+                self.comb += activates[np].eq(phase.activate)
+                cases[2**np] = [
+                    bank.activate.eq(phase.bank == nb),
+                    bank.activate_row.eq(phase.address)
+                ]
+            self.comb += Case(activates, cases)
+
+            # Bank precharge
+            precharges = Signal(len(phases))
+            cases      = {}
+            for np, phase in enumerate(phases):
+                self.comb += precharges[np].eq(phase.precharge)
+                cases[2**np] = [
+                    bank.precharge.eq((phase.bank == nb) | phase.address[10])
+                ]
+            self.comb += Case(precharges, cases)
+
+            # Bank writes
+            bank_write = Signal()
+            bank_write_col = Signal(range(ncols))
+            writes = Signal(len(phases))
+            cases  = {}
+            for np, phase in enumerate(phases):
+                self.comb += writes[np].eq(phase.write)
+                cases[2**np] = [
+                    bank_write.eq(phase.bank == nb),
+                    bank_write_col.eq(phase.address)
+                ]
+            self.comb += Case(writes, cases)
+            self.comb += [
+                bank.write_data.eq(Cat(*[phase.wrdata for phase in phases])),
+                bank.write_mask.eq(Cat(*[phase.wrdata_mask for phase in phases]))
+            ]
+
+            # Simulate write latency
+            for i in range(self.settings.write_latency):
+                new_bank_write     = Signal()
+                new_bank_write_col = Signal(range(ncols))
+                self.sync += [
+                    new_bank_write.eq(bank_write),
+                    new_bank_write_col.eq(bank_write_col)
+                ]
+                bank_write = new_bank_write
+                bank_write_col = new_bank_write_col
+
+            self.comb += [
+                bank.write.eq(bank_write),
+                bank.write_col.eq(bank_write_col)
+            ]
+
+            # Bank reads
+            reads = Signal(len(phases))
+            cases = {}
+            for np, phase in enumerate(phases):
+                self.comb += reads[np].eq(phase.read)
+                cases[2**np] = [
+                    bank.read.eq(phase.bank == nb),
+                    bank.read_col.eq(phase.address)
+            ]
+            self.comb += Case(reads, cases)
+
+        # Connect Banks to DFI phases (CMDs, Read datapath) ----------------------------------------
+        banks_read      = Signal()
+        banks_read_data = Signal(data_width)
+        self.comb += [
+            banks_read.eq(reduce(or_, [bank.read for bank in banks])),
+            banks_read_data.eq(reduce(or_, [bank.read_data for bank in banks]))
+        ]
+
+        # Simulate read latency --------------------------------------------------------------------
+        for i in range(self.settings.read_latency):
+            new_banks_read      = Signal()
+            new_banks_read_data = Signal(data_width)
+            self.sync += [
+                new_banks_read.eq(banks_read),
+                new_banks_read_data.eq(banks_read_data)
+            ]
+            banks_read      = new_banks_read
+            banks_read_data = new_banks_read_data
+
+        self.comb += [
+            Cat(*[phase.rddata_valid for phase in phases]).eq(banks_read),
+            Cat(*[phase.rddata for phase in phases]).eq(banks_read_data)
+        ]