Add more memory tests
[gram.git] / gram / phy / fakephy.py
1 # This file is Copyright (c) 2015-2020 Florent Kermarrec <florent@enjoy-digital.fr>
2 # This file is Copyright (c) 2020 Antmicro <www.antmicro.com>
3 # License: BSD
4
5 # SDRAM simulation PHY at DFI level tested with SDR/DDR/DDR2/LPDDR/DDR3
6 # TODO:
7 # - add multirank support.
8
9 from nmigen import *
10 from nmigen.utils import log2_int
11
12 from gram.common import burst_lengths
13 from gram.phy.dfi import *
14 from gram.modules import _speedgrade_timings, _technology_timings
15
16 from functools import reduce
17 from operator import or_
18
19 import struct
20
# Verbosity levels accepted by FakePHY(verbosity=...). Any level above SDRAM_VERBOSE_OFF
# instantiates the DFI timing checker.
SDRAM_VERBOSE_OFF, SDRAM_VERBOSE_STD, SDRAM_VERBOSE_DBG = range(3)
24
def Display(*args):
    """Placeholder for simulation logging; every argument is ignored.

    Returns a dummy assignment to a fresh, otherwise unused signal so call sites can
    still write ``m.d.sync += Display(...)``.
    """
    dummy = Signal()
    return dummy.eq(0)
27
def Assert(*args):
    """Placeholder for a simulation assertion; the condition is NOT checked.

    Returns a dummy assignment to a fresh, otherwise unused signal so call sites can
    still write ``m.d.sync += Assert(...)``.
    """
    dummy = Signal()
    return dummy.eq(0)
30
31 # Bank Model ---------------------------------------------------------------------------------------
32
class BankModel(Elaboratable):
    """Behavioural model of a single SDRAM bank for simulation.

    Tracks the open row (ACTIVATE/PRECHARGE) and stores data in a memory whose words
    cover ``burst_length * nphases`` column addresses, i.e. one full DFI beat of
    ``data_width`` bits.

    Parameters
    ----------
    data_width : int
        Width in bits of one memory word (all DFI phases concatenated).
    nrows, ncols : int
        Bank geometry.
    burst_length : int
        Burst length per DFI phase.
    nphases : int
        Number of DFI phases.
    we_granularity : int
        Write-enable granularity in bits. 0 disables byte masking, so writes are
        full-word. ``write_mask`` is per byte, so a non-zero value is expected to be 8.
    init : list of int or None
        Initial memory contents. Must not exceed ``mem_depth`` words.
    mem_depth : int
        Number of memory words actually instantiated (default 100, as before).
        Addresses wrap modulo this depth, so distinct locations alias once the
        model is addressed beyond it. The small default is presumably chosen to
        keep simulation of large devices tractable.
    """
    def __init__(self, data_width, nrows, ncols, burst_length, nphases, we_granularity, init,
                 mem_depth=100):
        self.activate = Signal()
        self.activate_row = Signal(range(nrows))
        self.precharge = Signal()

        self.write = Signal()
        self.write_col = Signal(range(ncols))
        self.write_data = Signal(data_width)
        self.write_mask = Signal(data_width//8)

        self.read = Signal()
        self.read_col = Signal(range(ncols))
        self.read_data = Signal(data_width)

        self.nphases = nphases
        self.nrows = nrows
        self.ncols = ncols
        self.burst_length = burst_length
        self.data_width = data_width
        self.we_granularity = we_granularity
        self.init = init
        self.mem_depth = mem_depth

    def elaborate(self, platform):
        m = Module()

        nrows = self.nrows
        ncols = self.ncols
        burst_length = self.burst_length
        data_width = self.data_width
        we_granularity = self.we_granularity
        mem_depth = self.mem_depth

        # Row state: PRECHARGE closes the bank (and wins over ACTIVATE), ACTIVATE opens a row.
        active = Signal()
        row = Signal(range(nrows))

        with m.If(self.precharge):
            m.d.sync += active.eq(0)
        with m.Elif(self.activate):
            m.d.sync += [
                active.eq(1),
                row.eq(self.activate_row),
            ]

        # One memory word holds burst_length*nphases consecutive columns.
        beat_shift = log2_int(burst_length*self.nphases)
        bank_mem_len = nrows*ncols//(burst_length*self.nphases)

        mem = Memory(width=data_width, depth=mem_depth, init=self.init)
        # A granularity of 0 is invalid for a memory write port (it divides by it), so
        # "no byte masking" maps to a full-width write port instead.
        write_port = mem.write_port(granularity=we_granularity if we_granularity else None)
        read_port = mem.read_port(domain="comb")
        m.submodules += read_port, write_port

        wraddr = Signal(range(bank_mem_len))
        rdaddr = Signal(range(bank_mem_len))

        # Addresses are folded into the instantiated depth (see ``mem_depth``).
        m.d.comb += [
            wraddr.eq((row*ncols | self.write_col)[beat_shift:] % mem_depth),
            rdaddr.eq((row*ncols | self.read_col)[beat_shift:] % mem_depth),
        ]

        # Reads and writes only take effect while a row is open.
        with m.If(active):
            m.d.comb += [
                write_port.addr.eq(wraddr),
                write_port.data.eq(self.write_data),
            ]

            # we_granularity is an elaboration-time parameter, not a hardware signal.
            if we_granularity:
                # Byte enables: a set mask bit suppresses the write of that byte.
                m.d.comb += write_port.en.eq(Repl(self.write, data_width//8) & ~self.write_mask)
            else:
                m.d.comb += write_port.en.eq(self.write)

            with m.If(self.read):
                m.d.comb += [
                    read_port.addr.eq(rdaddr),
                    self.read_data.eq(read_port.data),
                ]

        return m
108
109 # DFI Phase Model ----------------------------------------------------------------------------------
110
class DFIPhaseModel(Elaboratable):
    """Decode the DRAM command present on one DFI phase.

    Exposes the phase's bank/address and data signals under short names, plus four
    combinational strobes (``activate``, ``precharge``, ``write``, ``read``) derived from
    the active-low ``cs_n``/``ras_n``/``cas_n``/``we_n`` lines.
    """
    def __init__(self, dfi, n):
        phase = dfi.phases[n]
        self.phase = phase

        # Command / address fields.
        self.bank = phase.bank
        self.address = phase.address

        # Write datapath.
        self.wrdata = phase.wrdata
        self.wrdata_mask = phase.wrdata_mask

        # Read datapath.
        self.rddata = phase.rddata
        self.rddata_valid = phase.rddata_valid

        # Decoded command strobes.
        self.activate = Signal()
        self.precharge = Signal()
        self.write = Signal()
        self.read = Signal()

    def elaborate(self, platform):
        m = Module()
        p = self.phase

        # Row command: chip selected, RAS asserted, CAS idle; WE picks PRECHARGE vs ACTIVATE.
        row_cmd = ~p.cs_n & ~p.ras_n & p.cas_n
        # Column command: chip selected, CAS asserted, RAS idle; WE picks WRITE vs READ.
        col_cmd = ~p.cs_n & p.ras_n & ~p.cas_n

        with m.If(row_cmd):
            m.d.comb += self.activate.eq(p.we_n)
            m.d.comb += self.precharge.eq(~p.we_n)

        with m.If(col_cmd):
            m.d.comb += self.write.eq(~p.we_n)
            m.d.comb += self.read.eq(p.we_n)

        return m
145
146 # DFI Timings Checker ------------------------------------------------------------------------------
147
class SDRAMCMD:
    """A DRAM command known to the DFI timings checker.

    name -- mnemonic such as "ACT" or "PRE"
    enc  -- integer value of the command's (cs_n, ras_n, cas_n, we_n) bit pattern
    idx  -- index used to address per-command tables (assigned in DFITimingsChecker.add_cmds)
    """
    def __init__(self, name: str, enc: int, idx: int):
        self.name, self.enc, self.idx = name, enc, idx
153
154
class TimingRule:
    """Minimum delay between a ``prev`` command and a following ``curr`` command on a bank.

    ``delay`` is expressed in the checker's timestamp unit (picoseconds in
    DFITimingsChecker). ``name`` is a readable "PREV->CURR" label.
    """
    def __init__(self, prev: str, curr: str, delay: int):
        self.name = "->".join((prev, curr))
        self.prev = prev
        self.curr = curr
        self.delay = delay
161
162
class DFITimingsChecker(Elaboratable):
    """Monitor the DFI command bus and check SDRAM timing constraints.

    Commands are decoded from the cs_n/ras_n/cas_n/we_n lines of every DFI phase and
    time-stamped in picoseconds. Each command on a bank is checked against ``RULES``
    relative to the previous command seen on that bank. Refresh intervals (tREFI) are
    tracked as well.

    NOTE: the module-level ``Assert``/``Display`` helpers in this file are currently
    stubs, so violations are computed but not reported.
    """
    CMDS = [
        # Name, (cs_n, ras_n, cas_n, we_n) bit pattern, MSB first
        ("PRE",  "0010"),  # Precharge
        ("REF",  "0001"),  # Refresh (CKE is not monitored, so self-refresh looks the same)
        ("ACT",  "0011"),  # Activate
        ("RD",   "0101"),  # Read
        ("WR",   "0100"),  # Write
        ("ZQCS", "0110"),  # ZQCS
    ]

    RULES = [
        # tRP
        ("PRE", "ACT", "tRP"),
        ("PRE", "REF", "tRP"),
        # tRCD
        ("ACT", "WR", "tRCD"),
        ("ACT", "RD", "tRCD"),
        # tRAS
        ("ACT", "PRE", "tRAS"),
        # tRFC
        ("REF", "PRE", "tRFC"),
        ("REF", "ACT", "tRFC"),
        # tCCD
        ("WR", "RD", "tCCD"),
        ("WR", "WR", "tCCD"),
        ("RD", "RD", "tCCD"),
        ("RD", "WR", "tCCD"),
        # tRC
        ("ACT", "ACT", "tRC"),
        # tWR
        ("WR", "PRE", "tWR"),
        # tWTR
        ("WR", "RD", "tWTR"),
        # tZQCS
        ("ZQCS", "ACT", "tZQCS"),
    ]

    def add_cmds(self):
        """Build ``self.cmds``: command name -> SDRAMCMD (encoding and table index)."""
        self.cmds = {}
        for idx, (name, pattern) in enumerate(self.CMDS):
            self.cmds[name] = SDRAMCMD(name, int(pattern, 2), idx)

    def add_rule(self, prev, curr, delay):
        """Append a rule; ``delay`` is either an int (ps) or a key into ``self.timings``."""
        if not isinstance(delay, int):
            delay = self.timings[delay]
        self.rules.append(TimingRule(prev, curr, delay))

    def add_rules(self):
        """Build ``self.rules`` from ``RULES`` using the already converted timings."""
        self.rules = []
        for rule in self.RULES:
            self.add_rule(*rule)

    # Convert ns to ps
    def ns_to_ps(self, val):
        return int(val * 1e3)

    def ck_ns_to_ps(self, val, tck):
        """Convert a ``(clock_cycles, ns)`` pair to ps, taking the larger of the two."""
        c, t = val
        c = 0 if c is None else c * tck
        t = 0 if t is None else t
        return self.ns_to_ps(max(c, t))

    def prepare_timings(self, timings, refresh_mode, memtype):
        """Convert the raw timing table (ns) into integer picoseconds in ``self.timings``.

        Keys in ``CK_NS`` hold ``(clock_cycles, ns)`` pairs. When ``refresh_mode`` is
        given, the refresh-related keys are first indexed by it. ``None`` values become 0.
        tRC is derived as tRAS + tRP, and tWR/tWTR are extended by the length of a write
        burst.
        """
        CK_NS = ["tRFC", "tWTR", "tFAW", "tCCD", "tRRD", "tZQCS"]
        REF = ["tREFI", "tRFC"]
        self.timings = timings
        new_timings = {}

        tck = self.timings["tCK"]

        for key, val in self.timings.items():
            if refresh_mode is not None and key in REF:
                val = val[refresh_mode]

            if val is None:
                val = 0
            elif key in CK_NS:
                val = self.ck_ns_to_ps(val, tck)
            else:
                val = self.ns_to_ps(val)

            new_timings[key] = val

        new_timings["tRC"] = new_timings["tRAS"] + new_timings["tRP"]

        # Adjust timings relative to write burst - tWR & tWTR
        wrburst = burst_lengths[memtype] if memtype == "SDR" else burst_lengths[memtype] // 2
        wrburst = (new_timings["tCK"] * (wrburst - 1))
        new_timings["tWR"] = new_timings["tWR"] + wrburst
        new_timings["tWTR"] = new_timings["tWTR"] + wrburst

        self.timings = new_timings

    def __init__(self, dfi, nbanks, nphases, timings, refresh_mode, memtype, verbose=False):
        # prepare_timings() leaves the converted (picosecond) table in self.timings; the
        # rules and elaborate() both rely on it, so it must not be overwritten afterwards.
        self.prepare_timings(timings, refresh_mode, memtype)
        self.add_cmds()
        self.add_rules()
        self.nphases = nphases
        self.nbanks = nbanks
        self.dfi = dfi
        self.refresh_mode = refresh_mode
        self.memtype = memtype
        self.verbose = verbose

    def elaborate(self, platform):
        m = Module()

        # Cycle counter in DFI-phase units (advances by nphases each system clock).
        cnt = Signal(64)
        m.d.sync += cnt.eq(cnt + self.nphases)

        phases = self.dfi.phases
        nbanks = self.nbanks
        refresh_mode = self.refresh_mode
        memtype = self.memtype
        tck = int(self.timings["tCK"])

        # Per bank: timestamp of the last occurrence of each command, and the last command.
        last_cmd_ps = [[Signal.like(cnt) for _ in range(len(self.cmds))] for _ in range(nbanks)]
        last_cmd = [Signal(4) for _ in range(nbanks)]

        # Circular buffer of the 4 most recent ACT timestamps (for tRRD/tFAW).
        act_ps = Array([Signal.like(cnt) for _ in range(4)])
        act_curr = Signal(range(4))

        ref_issued = Signal(self.nphases)

        for np, phase in enumerate(phases):
            ps = Signal.like(cnt)
            m.d.comb += ps.eq((cnt + np)*tck)
            state = Signal(4)
            m.d.comb += state.eq(Cat(phase.we_n, phase.cas_n, phase.ras_n, phase.cs_n))
            all_banks = Signal()

            # REF and PRE with A10 set apply to every bank.
            m.d.comb += all_banks.eq(
                (self.cmds["REF"].enc == state) |
                ((self.cmds["PRE"].enc == state) & phase.address[10])
            )

            # tREFI
            m.d.comb += ref_issued[np].eq(self.cmds["REF"].enc == state)

            # TODO: bring back verbose logging ("[%016dps] P%0d [B%0d] <CMD>" per command).

            # Bank command monitoring
            for i in range(nbanks):
                for curr in self.cmds.values():
                    cmd_recv = Signal()
                    m.d.comb += cmd_recv.eq(((phase.bank == i) | all_banks) & (state == curr.enc))

                    # Checking rules from self.rules
                    for prev in self.cmds.values():
                        for rule in self.rules:
                            if rule.prev == prev.name and rule.curr == curr.name:
                                # "[%016dps] <rule> violation on bank %0d"
                                m.d.sync += Assert(~(cmd_recv & (last_cmd[i] == prev.enc) &
                                                     (ps < (last_cmd_ps[i][prev.idx] + rule.delay))))

                    # Save command timestamp. This must be registered: the checks above
                    # compare against the *previous* command on the bank.
                    with m.If(cmd_recv):
                        m.d.sync += [
                            last_cmd_ps[i][curr.idx].eq(ps),
                            last_cmd[i].eq(state),
                        ]

                    # tRRD & tFAW
                    if curr.name == "ACT":
                        act_next = Signal.like(act_curr)
                        m.d.comb += act_next.eq(act_curr + 1)

                        # act_curr points to newest ACT timestamp (tRRD check, disabled):
                        # m.d.sync += Assert(~(cmd_recv & (ps < (act_ps[act_curr] + int(self.timings["tRRD"])))))

                        # act_next points to the oldest ACT timestamp (tFAW check, disabled):
                        # m.d.sync += Assert(~(cmd_recv & (ps < (act_ps[act_next] + int(self.timings["tFAW"])))))

                        # Save ACT timestamp in a circular buffer
                        with m.If(cmd_recv):
                            m.d.sync += [
                                act_ps[act_next].eq(ps),
                                act_curr.eq(act_next),
                            ]

        # tREFI (timestamps below use `ps` of the last phase)
        ref_ps = Signal.like(cnt)
        ref_ps_mod = Signal.like(cnt)
        ref_ps_diff = Signal(signed(64))
        curr_diff = Signal.like(ref_ps_diff)

        m.d.comb += curr_diff.eq(ps - (ref_ps + int(self.timings["tREFI"])))

        # Work in 64ms periods
        with m.If(ref_ps_mod < int(64e9)):
            m.d.sync += ref_ps_mod.eq(ref_ps_mod + int(self.nphases * self.timings["tCK"]))
        with m.Else():
            m.d.sync += ref_ps_mod.eq(0)

        # Update timestamp and accumulated lateness
        with m.If(ref_issued != 0):
            m.d.sync += [
                ref_ps.eq(ps),
                ref_ps_diff.eq(ref_ps_diff - curr_diff),
            ]

        # "[%016dps] tREFI violation (64ms period): %0d"
        m.d.sync += Assert(~((ref_ps_mod == 0) & (ref_ps_diff > 0)))

        # TODO: bring back verbose reporting of late refreshes ("Late refresh" /
        # "tREFI violation") once a logging mechanism is available.

        # There is a maximum delay between refreshes on >=DDR
        ref_limit = {"1x": 9, "2x": 17, "4x": 36}
        if memtype != "SDR":
            refresh_mode = "1x" if refresh_mode is None else refresh_mode
            ref_done = Signal()
            with m.If(ref_issued != 0):
                m.d.sync += ref_done.eq(1)

            # Too many postponed refreshes: more than ref_limit*tREFI has elapsed since the
            # last REF. (The previous form `ref_ps > ps + ...` could never be true because
            # ref_ps is a past timestamp.)
            # "[%016dps] tREFI violation (too many postponed refreshes)"
            with m.If((ref_issued == 0) & ref_done &
                      (ps > (ref_ps + int(ref_limit[refresh_mode] * self.timings["tREFI"])))):
                m.d.sync += ref_done.eq(0)

        return m
423
class FakePHY(Elaboratable):
    """Simulation PHY that models the SDRAM device behind a DFI interface.

    Parameters
    ----------
    module :
        SDRAM module description (geometry and timing settings).
    settings :
        PHY settings (memtype, nphases, DFI data width, read/write latency, ...).
    clk_freq : float
        System clock frequency in Hz; only used to derive tCK for the timing checker.
    we_granularity : int
        Write-enable granularity passed to every BankModel (0 disables byte masking).
    init : list of int or None
        Optional initial memory contents as a flat list of 32-bit words. The list is
        not modified.
    address_mapping : str
        How ``init`` is laid out across banks: "ROW_BANK_COL" or "BANK_ROW_COL".
    verbosity : int
        Any level above SDRAM_VERBOSE_OFF instantiates a DFITimingsChecker.
    """
    def __prepare_bank_init_data(self, init, nbanks, nrows, ncols, data_width, address_mapping):
        """Split a flat list of 32-bit init words into one init list per bank.

        Words are regrouped to ``data_width`` bits (the BankModel word width) and then
        distributed according to ``address_mapping``.

        Raises ValueError for an unsupported ``address_mapping``.
        """
        mem_size = (self.settings.databits//8)*(nrows*ncols*nbanks)
        bank_size = mem_size // nbanks
        model_bank_size = bank_size // (data_width//8)
        model_column_size = model_bank_size // nrows
        model_data_ratio = data_width // 32
        data_width_bytes = data_width // 8
        bank_init = [[] for _ in range(nbanks)]

        # Work on a copy so neither the caller's list nor self.init is modified.
        init = list(init)

        # Pad init if too short
        remainder = len(init) % data_width_bytes
        if remainder != 0:
            init.extend([0]*(data_width_bytes - remainder))

        # Convert init data width from 32-bit to data_width if needed
        if model_data_ratio > 1:
            # Pack several 32-bit words into one wide word (lowest-index word in the LSBs).
            new_init = [0]*(len(init)//model_data_ratio)
            for i in range(0, len(init), model_data_ratio):
                ints = init[i:i+model_data_ratio]
                strs = "".join("{:08x}".format(x) for x in reversed(ints))
                new_init[i//model_data_ratio] = int(strs, 16)
            init = new_init
        elif model_data_ratio == 0:
            # Split each 32-bit word into narrower (8- or 16-bit) words.
            assert data_width_bytes in [1, 2]
            model_data_ratio = 4 // data_width_bytes
            struct_unpack_patterns = {1: "4B", 2: "2H"}
            new_init = [0]*int(len(init)*model_data_ratio)
            for i in range(len(init)):
                new_init[model_data_ratio*i:model_data_ratio*(i+1)] = struct.unpack(
                    struct_unpack_patterns[data_width_bytes],
                    struct.pack("I", init[i])
                )[0:model_data_ratio]
            init = new_init

        if address_mapping == "ROW_BANK_COL":
            for row in range(nrows):
                for bank in range(nbanks):
                    start = (row*nbanks*model_column_size + bank*model_column_size)
                    end = min(start + model_column_size, len(init))
                    if start > len(init):
                        break
                    bank_init[bank].extend(init[start:end])
        elif address_mapping == "BANK_ROW_COL":
            for bank in range(nbanks):
                start = bank*model_bank_size
                end = min(start + model_bank_size, len(init))
                if start > len(init):
                    break
                bank_init[bank] = init[start:end]
        else:
            raise ValueError("Unsupported address_mapping: {!r}".format(address_mapping))

        return bank_init

    def __init__(self, module, settings, clk_freq=100e6,
                 we_granularity=8,
                 init=None,
                 address_mapping="ROW_BANK_COL",
                 verbosity=SDRAM_VERBOSE_OFF):

        # Parameters -------------------------------------------------------------------------------
        self.burst_length = {
            "SDR":   1,
            "DDR":   2,
            "LPDDR": 2,
            "DDR2":  2,
            "DDR3":  2,
            "DDR4":  2,
        }[settings.memtype]

        self.addressbits = module.geom_settings.addressbits
        self.bankbits = module.geom_settings.bankbits
        self.rowbits = module.geom_settings.rowbits
        self.colbits = module.geom_settings.colbits

        self.settings = settings
        self.module = module

        self.verbosity = verbosity
        self.clk_freq = clk_freq
        self.we_granularity = we_granularity

        # `None` instead of a shared mutable `[]` default.
        self.init = [] if init is None else init
        # Needed by elaborate() when init data is distributed across banks.
        self.address_mapping = address_mapping

        # DFI Interface ----------------------------------------------------------------------------
        self.dfi = Interface(
            addressbits=self.addressbits,
            bankbits=self.bankbits,
            nranks=self.settings.nranks,
            databits=self.settings.dfi_databits,
            nphases=self.settings.nphases
        )

    def elaborate(self, platform):
        m = Module()

        nphases = self.settings.nphases
        nbanks = 2**self.bankbits
        nrows = 2**self.rowbits
        ncols = 2**self.colbits
        data_width = self.settings.dfi_databits*nphases

        # DFI phases -------------------------------------------------------------------------------
        phases = [DFIPhaseModel(self.dfi, n) for n in range(nphases)]
        m.submodules += phases

        # DFI timing checker -----------------------------------------------------------------------
        if self.verbosity > SDRAM_VERBOSE_OFF:
            timings = {"tCK": (1e9 / self.clk_freq) / nphases}

            for name in _speedgrade_timings + _technology_timings:
                timings[name] = self.module.get(name)

            timing_checker = DFITimingsChecker(
                dfi=self.dfi,
                nbanks=nbanks,
                nphases=nphases,
                timings=timings,
                refresh_mode=self.module.timing_settings.fine_refresh_mode,
                memtype=self.settings.memtype,
                verbose=self.verbosity > SDRAM_VERBOSE_DBG)
            m.submodules.timing_checker = timing_checker

        # Bank init data ---------------------------------------------------------------------------
        bank_init = [None for _ in range(nbanks)]

        if self.init:
            bank_init = self.__prepare_bank_init_data(
                init=self.init,
                nbanks=nbanks,
                nrows=nrows,
                ncols=ncols,
                data_width=data_width,
                address_mapping=self.address_mapping,
            )

        # Banks ------------------------------------------------------------------------------------
        banks = [BankModel(
            data_width=data_width,
            nrows=nrows,
            ncols=ncols,
            burst_length=self.burst_length,
            nphases=nphases,
            we_granularity=self.we_granularity,
            init=bank_init[i]) for i in range(nbanks)]
        m.submodules += banks

        # Connect DFI phases to Banks (CMDs, Write datapath) ---------------------------------------
        # Each command is only forwarded when exactly one phase carries it (one-hot match).
        for nb, bank in enumerate(banks):
            # Bank activate
            activates = Signal(len(phases))
            m.d.comb += activates.eq(Cat(*[phase.activate for phase in phases]))
            with m.Switch(activates):
                for np, phase in enumerate(phases):
                    with m.Case(2**np):
                        m.d.comb += [
                            bank.activate.eq(phase.bank == nb),
                            bank.activate_row.eq(phase.address)
                        ]

            # Bank precharge (A10 set means "precharge all banks")
            precharges = Signal(len(phases))
            m.d.comb += precharges.eq(Cat(*[phase.precharge for phase in phases]))
            with m.Switch(precharges):
                for np, phase in enumerate(phases):
                    with m.Case(2**np):
                        m.d.comb += bank.precharge.eq((phase.bank == nb) | phase.address[10])

            # Bank writes
            bank_write = Signal()
            bank_write_col = Signal(range(ncols))
            writes = Signal(len(phases))
            m.d.comb += writes.eq(Cat(*[phase.write for phase in phases]))
            with m.Switch(writes):
                for np, phase in enumerate(phases):
                    with m.Case(2**np):
                        m.d.comb += [
                            bank_write.eq(phase.bank == nb),
                            bank_write_col.eq(phase.address)
                        ]
            m.d.comb += [
                bank.write_data.eq(Cat(*[phase.wrdata for phase in phases])),
                bank.write_mask.eq(Cat(*[phase.wrdata_mask for phase in phases]))
            ]

            # Simulate write latency
            for i in range(self.settings.write_latency):
                new_bank_write = Signal()
                new_bank_write_col = Signal(range(ncols))
                m.d.sync += [
                    new_bank_write.eq(bank_write),
                    new_bank_write_col.eq(bank_write_col)
                ]
                bank_write = new_bank_write
                bank_write_col = new_bank_write_col

            m.d.comb += [
                bank.write.eq(bank_write),
                bank.write_col.eq(bank_write_col)
            ]

            # Bank reads
            reads = Signal(len(phases))
            m.d.comb += reads.eq(Cat(*[phase.read for phase in phases]))
            with m.Switch(reads):
                for np, phase in enumerate(phases):
                    with m.Case(2**np):
                        m.d.comb += [
                            bank.read.eq(phase.bank == nb),
                            bank.read_col.eq(phase.address),
                        ]

        # Connect Banks to DFI phases (CMDs, Read datapath) ----------------------------------------
        banks_read = Signal()
        banks_read_data = Signal(data_width)
        m.d.comb += [
            banks_read.eq(reduce(or_, [bank.read for bank in banks])),
            banks_read_data.eq(reduce(or_, [bank.read_data for bank in banks]))
        ]

        # Simulate read latency --------------------------------------------------------------------
        for i in range(self.settings.read_latency):
            new_banks_read = Signal()
            new_banks_read_data = Signal(data_width)
            m.d.sync += [
                new_banks_read.eq(banks_read),
                new_banks_read_data.eq(banks_read_data)
            ]
            banks_read = new_banks_read
            banks_read_data = new_banks_read_data

        m.d.comb += [
            Cat(*[phase.rddata_valid for phase in phases]).eq(banks_read),
            Cat(*[phase.rddata for phase in phases]).eq(banks_read_data)
        ]

        return m
659
660 return m