Fix timings in simulation to prevent tDLLK errors
[gram.git] / gram / phy / fakephy.py
1 # This file is Copyright (c) 2015-2020 Florent Kermarrec <florent@enjoy-digital.fr>
2 # This file is Copyright (c) 2020 Antmicro <www.antmicro.com>
3 # License: BSD
4
5 # SDRAM simulation PHY at DFI level tested with SDR/DDR/DDR2/LPDDR/DDR3
6 # TODO:
7 # - add multirank support.
8
9 from nmigen import *
10 from nmigen.utils import log2_int
11
12 from gram.common import burst_lengths
13 from gram.phy.dfi import *
14 from gram.modules import _speedgrade_timings, _technology_timings
15
16 from functools import reduce
17 from operator import or_
18
19 import struct
20
# Verbosity levels for FakePHY(verbosity=...): OFF disables the DFI timing
# checker; any higher level instantiates it.
SDRAM_VERBOSE_OFF, SDRAM_VERBOSE_STD, SDRAM_VERBOSE_DBG = range(3)
24
def _placeholder_stmt():
    """Return a dummy statement: a fresh, otherwise unused Signal driven to 0."""
    return Signal().eq(0)


def Display(*args):
    """Placeholder for a simulation print; all arguments are ignored.

    Returns a dummy assignment so call sites can still append the result to
    a clock domain (e.g. ``m.d.sync += Display(...)``).
    """
    return _placeholder_stmt()


def Assert(*args):
    """Placeholder for a simulation assertion; all arguments are ignored.

    No check is performed: the returned value is a dummy assignment that
    callers can append to a clock domain.
    """
    return _placeholder_stmt()
30
31 # Bank Model ---------------------------------------------------------------------------------------
32
class BankModel(Elaboratable):
    """Simulation model of a single SDRAM bank.

    Tracks the open row: ACTIVATE opens ``activate_row`` and PRECHARGE closes
    it, with PRECHARGE taking priority when both are asserted. It also derives
    the memory word addresses that writes and reads map to.

    The backing Memory is currently disabled. Writes are not stored, and
    while a row is open ``read`` drives ``read_data`` to the constant
    0xDEADBEEF. ``init`` is kept for the (disabled) Memory initialisation.
    """

    def __init__(self, data_width, nrows, ncols, burst_length, nphases, we_granularity, init):
        # Row open/close requests.
        self.activate = Signal()
        self.activate_row = Signal(range(nrows))
        self.precharge = Signal()

        # Write request: column, data and per-byte mask.
        self.write = Signal()
        self.write_col = Signal(range(ncols))
        self.write_data = Signal(data_width)
        self.write_mask = Signal(data_width//8)

        # Read request and returned data.
        self.read = Signal()
        self.read_col = Signal(range(ncols))
        self.read_data = Signal(data_width)

        # Static configuration consumed by elaborate().
        self.nphases, self.nrows, self.ncols = nphases, nrows, ncols
        self.burst_length = burst_length
        self.data_width = data_width
        self.we_granularity = we_granularity
        self.init = init

    def elaborate(self, platform):
        m = Module()

        ncols = self.ncols
        words_per_access = self.burst_length*self.nphases

        # Open-row state; PRECHARGE wins over a simultaneous ACTIVATE.
        active = Signal()
        row = Signal(range(self.nrows))
        with m.If(self.precharge):
            m.d.sync += active.eq(0)
        with m.Elif(self.activate):
            m.d.sync += [active.eq(1), row.eq(self.activate_row)]

        # Each memory word covers burst_length*nphases column addresses, so the
        # low log2(burst_length*nphases) bits of {row, col} are dropped.
        bank_mem_len = self.nrows*ncols//words_per_access
        wraddr = Signal(range(bank_mem_len))
        rdaddr = Signal(range(bank_mem_len))
        lsb = log2_int(words_per_access)
        m.d.comb += [
            wraddr.eq((row*ncols | self.write_col)[lsb:]),
            rdaddr.eq((row*ncols | self.read_col)[lsb:]),
        ]

        # TODO: the Memory-backed datapath is disabled. When restored it should:
        #   - instantiate Memory(width=data_width, depth=bank_mem_len, init=init)
        #     with a writable port (we_granularity) and an async read port;
        #   - while active, write write_data at wraddr, using
        #     we = Replicate(write, data_width//8) & ~write_mask when
        #     we_granularity is set, else we = write;
        #   - on read, drive read_data from the read port at rdaddr.
        # Until then a read of an open row returns a fixed marker pattern.
        with m.If(active & self.read):
            m.d.comb += self.read_data.eq(0xDEADBEEF)

        return m
109
110 # DFI Phase Model ----------------------------------------------------------------------------------
111
class DFIPhaseModel(Elaboratable):
    """Decode the SDRAM command present on one DFI phase.

    Produces combinational strobes (``activate``, ``precharge``, ``write``,
    ``read``) from cs_n/ras_n/cas_n/we_n. It also re-exports the phase's
    bank, address and data signals so banks can be wired to this object
    directly.
    """

    def __init__(self, dfi, n):
        phase = dfi.phases[n]
        self.phase = phase

        # Pass-through views of the underlying DFI phase.
        self.bank, self.address = phase.bank, phase.address
        self.wrdata, self.wrdata_mask = phase.wrdata, phase.wrdata_mask
        self.rddata, self.rddata_valid = phase.rddata, phase.rddata_valid

        # Decoded command strobes.
        self.activate = Signal()
        self.precharge = Signal()
        self.write = Signal()
        self.read = Signal()

    def elaborate(self, platform):
        m = Module()
        p = self.phase
        selected = ~p.cs_n

        # RAS low, CAS high: ACTIVATE when WE is high, PRECHARGE when WE is low.
        with m.If(selected & ~p.ras_n & p.cas_n):
            m.d.comb += self.activate.eq(p.we_n)
            m.d.comb += self.precharge.eq(~p.we_n)

        # RAS high, CAS low: WRITE when WE is low, READ when WE is high.
        with m.If(selected & p.ras_n & ~p.cas_n):
            m.d.comb += self.write.eq(~p.we_n)
            m.d.comb += self.read.eq(p.we_n)

        return m
146
147 # DFI Timings Checker ------------------------------------------------------------------------------
148
class SDRAMCMD:
    """An SDRAM command descriptor.

    ``name`` is the mnemonic, ``enc`` the 4-bit cs_n/ras_n/cas_n/we_n code
    (MSB first) and ``idx`` the command's position in the command table,
    used to index per-command timestamp storage.
    """

    def __init__(self, name: str, enc: int, idx: int):
        self.name, self.enc, self.idx = name, enc, idx
154
155
class TimingRule:
    """Minimum spacing ``delay`` required from command ``prev`` to command ``curr``.

    ``name`` is a human-readable label of the form "PREV->CURR".
    """

    def __init__(self, prev: str, curr: str, delay: int):
        self.name = "->".join((prev, curr))
        self.prev, self.curr = prev, curr
        self.delay = delay
162
163
class DFITimingsChecker(Elaboratable):
    """Simulation-only checker for SDRAM command timings at the DFI level.

    Decodes the command on every DFI phase and timestamps it in picoseconds.
    It then checks the spacing between the command pairs listed in ``RULES``
    per bank, and elaborates refresh-spacing (tREFI) checks.

    All timing values are converted to integer picoseconds by
    ``prepare_timings`` and kept in ``self.timings``.

    Note: violations are emitted through ``Assert``, which is currently a
    placeholder in this module, so they are elaborated but not reported.
    """

    CMDS = [
        # Name, cs & ras & cas & we value
        ("PRE",  "0010"),  # Precharge
        ("REF",  "0001"),  # Refresh
        ("ACT",  "0011"),  # Activate
        ("RD",   "0101"),  # Read
        ("WR",   "0100"),  # Write
        ("ZQCS", "0110"),  # ZQCS
    ]

    RULES = [
        # tRP
        ("PRE",  "ACT", "tRP"),
        ("PRE",  "REF", "tRP"),
        # tRCD
        ("ACT",  "WR",  "tRCD"),
        ("ACT",  "RD",  "tRCD"),
        # tRAS
        ("ACT",  "PRE", "tRAS"),
        # tRFC
        ("REF",  "PRE", "tRFC"),
        ("REF",  "ACT", "tRFC"),
        # tCCD
        ("WR",   "RD",  "tCCD"),
        ("WR",   "WR",  "tCCD"),
        ("RD",   "RD",  "tCCD"),
        ("RD",   "WR",  "tCCD"),
        # tRC
        ("ACT",  "ACT", "tRC"),
        # tWR
        ("WR",   "PRE", "tWR"),
        # tWTR
        ("WR",   "RD",  "tWTR"),
        # tZQCS
        ("ZQCS", "ACT", "tZQCS"),
    ]

    def add_cmds(self):
        """Build ``self.cmds``: command name -> SDRAMCMD(name, encoding, table index)."""
        self.cmds = {}
        for idx, (name, pattern) in enumerate(self.CMDS):
            self.cmds[name] = SDRAMCMD(name, int(pattern, 2), idx)

    def add_rule(self, prev, curr, delay):
        """Append a TimingRule; a non-int ``delay`` is a key into ``self.timings`` (ps)."""
        if not isinstance(delay, int):
            delay = self.timings[delay]
        self.rules.append(TimingRule(prev, curr, delay))

    def add_rules(self):
        """Build ``self.rules`` from the ``RULES`` table."""
        self.rules = []
        for rule in self.RULES:
            self.add_rule(*rule)

    # Convert ns to ps
    def ns_to_ps(self, val):
        """Convert nanoseconds to integer picoseconds (truncating)."""
        return int(val * 1e3)

    def ck_ns_to_ps(self, val, tck):
        """Convert a ``(clock_cycles, ns)`` pair to ps, taking the larger of the two.

        A None element counts as 0; ``tck`` is the clock period in ns.
        """
        c, t = val
        c = 0 if c is None else c * tck
        t = 0 if t is None else t
        return self.ns_to_ps(max(c, t))

    def prepare_timings(self, timings, refresh_mode, memtype):
        """Convert raw ``timings`` (ns or (ck, ns) pairs) into ``self.timings`` in ps.

        For refresh-related keys, the value for ``refresh_mode`` is selected
        when a mode is given. None values become 0. tRC is derived as
        tRAS + tRP, and tWR/tWTR are extended by the write burst duration.
        """
        CK_NS = ["tRFC", "tWTR", "tFAW", "tCCD", "tRRD", "tZQCS"]
        REF = ["tREFI", "tRFC"]
        tck = timings["tCK"]
        new_timings = {}

        for key, val in timings.items():
            if refresh_mode is not None and key in REF:
                val = val[refresh_mode]

            if val is None:
                val = 0
            elif key in CK_NS or isinstance(val, tuple):
                # (clock cycles, ns) pair. A tuple-valued timing outside CK_NS
                # would otherwise crash in ns_to_ps().
                val = self.ck_ns_to_ps(val, tck)
            else:
                val = self.ns_to_ps(val)

            new_timings[key] = val

        new_timings["tRC"] = new_timings["tRAS"] + new_timings["tRP"]

        # Adjust timings relative to write burst - tWR & tWTR
        wrburst = burst_lengths[memtype] if memtype == "SDR" else burst_lengths[memtype] // 2
        wrburst = new_timings["tCK"] * (wrburst - 1)
        new_timings["tWR"] = new_timings["tWR"] + wrburst
        new_timings["tWTR"] = new_timings["tWTR"] + wrburst

        self.timings = new_timings

    def __init__(self, dfi, nbanks, nphases, timings, refresh_mode, memtype, verbose=False):
        # prepare_timings() leaves the picosecond timings in self.timings. The
        # rule delays and elaborate() both rely on them, so self.timings must
        # not be overwritten with the raw ns values afterwards.
        self.prepare_timings(timings, refresh_mode, memtype)
        self.add_cmds()
        self.add_rules()
        self.nphases = nphases
        self.nbanks = nbanks
        self.dfi = dfi
        self.refresh_mode = refresh_mode
        self.memtype = memtype
        self.verbose = verbose

    def elaborate(self, platform):
        m = Module()

        # Time base counted in DFI phases: phase n of the current cycle sits
        # at (cnt + n) * tCK picoseconds.
        cnt = Signal(64)
        m.d.sync += cnt.eq(cnt + self.nphases)

        tck = int(self.timings["tCK"])
        nbanks = self.nbanks

        # Per bank: timestamp of the last occurrence of each command (indexed
        # by SDRAMCMD.idx) and the encoding of the most recent command.
        last_cmd_ps = [[Signal.like(cnt) for _ in range(len(self.cmds))] for _ in range(nbanks)]
        last_cmd = [Signal(4) for _ in range(nbanks)]

        # Circular buffer with the timestamps of the last four ACTs (tRRD/tFAW).
        act_ps = Array([Signal.like(cnt) for _ in range(4)])
        act_curr = Signal(range(4))

        # Bit n is set when phase n issues a REF in the current cycle.
        ref_issued = Signal(self.nphases)

        for np, phase in enumerate(self.dfi.phases):
            ps = Signal.like(cnt)
            m.d.comb += ps.eq((cnt + np)*tck)

            # Command code: cs_n, ras_n, cas_n, we_n from MSB to LSB.
            state = Signal(4)
            m.d.comb += state.eq(Cat(phase.we_n, phase.cas_n, phase.ras_n, phase.cs_n))

            # REF, and PRE with A10 set (precharge all), apply to every bank.
            all_banks = Signal()
            m.d.comb += all_banks.eq(
                (self.cmds["REF"].enc == state) |
                ((self.cmds["PRE"].enc == state) & phase.address[10])
            )

            # tREFI
            m.d.comb += ref_issued[np].eq(self.cmds["REF"].enc == state)

            # TODO: bring back per-command debug logging when self.verbose is
            # set ("[%016dps] P%0d <CMD>" / "[%016dps] P%0d B%0d <CMD>").

            # Bank command monitoring
            for i in range(nbanks):
                for curr in self.cmds.values():
                    cmd_recv = Signal()
                    m.d.comb += cmd_recv.eq(((phase.bank == i) | all_banks) & (state == curr.enc))

                    # Checking rules from self.rules: `curr` must not follow
                    # a matching `prev` on this bank sooner than rule.delay.
                    for prev in self.cmds.values():
                        for rule in self.rules:
                            if rule.prev == prev.name and rule.curr == curr.name:
                                # "[%016dps] {rule.name} violation on bank %0d"
                                m.d.sync += Assert(~(cmd_recv & (last_cmd[i] == prev.enc) &
                                                     (ps < (last_cmd_ps[i][prev.idx] + rule.delay))))

                    # Remember when this command last happened on this bank.
                    # These must be registers: a combinational assignment
                    # would only hold the timestamp during the issuing cycle.
                    with m.If(cmd_recv):
                        m.d.sync += [
                            last_cmd_ps[i][curr.idx].eq(ps),
                            last_cmd[i].eq(state),
                        ]

                    # tRRD & tFAW
                    if curr.name == "ACT":
                        act_next = Signal.like(act_curr)
                        m.d.comb += act_next.eq(act_curr + 1)

                        # act_curr points to the newest ACT timestamp (tRRD) and
                        # act_next to the oldest of the last four (tFAW). These
                        # checks can be re-enabled now that self.timings is in ps:
                        # m.d.sync += Assert(~(cmd_recv & (ps < (act_ps[act_curr] + int(self.timings["tRRD"])))))
                        # m.d.sync += Assert(~(cmd_recv & (ps < (act_ps[act_next] + int(self.timings["tFAW"])))))

                        # Save ACT timestamp in the circular buffer.
                        with m.If(cmd_recv):
                            m.d.sync += [
                                act_ps[act_next].eq(ps),
                                act_curr.eq(act_next),
                            ]

        # The refresh checks below use the timestamp of the last phase.
        last_ps = ps

        # tREFI: accumulate how late refreshes are, checked per 64 ms window.
        ref_ps = Signal.like(cnt)
        ref_ps_mod = Signal.like(cnt)
        ref_ps_diff = Signal(signed(64))
        curr_diff = Signal.like(ref_ps_diff)

        m.d.comb += curr_diff.eq(last_ps - (ref_ps + int(self.timings["tREFI"])))

        # Work in 64ms periods (64e9 ps)
        with m.If(ref_ps_mod < int(64e9)):
            m.d.sync += ref_ps_mod.eq(ref_ps_mod + self.nphases*tck)
        with m.Else():
            m.d.sync += ref_ps_mod.eq(0)

        # Update timestamp and difference
        with m.If(ref_issued != 0):
            m.d.sync += [
                ref_ps.eq(last_ps),
                ref_ps_diff.eq(ref_ps_diff - curr_diff),
            ]

        # "[%016dps] tREFI violation (64ms period): %0d"
        m.d.sync += Assert(~((ref_ps_mod == 0) & (ref_ps_diff > 0)))

        # TODO: bring back "Late refresh" / "tREFI violation" logging when verbose.

        # There is a maximum delay between refreshes on >=DDR, expressed in
        # multiples of tREFI for each fine refresh mode.
        ref_limit = {"1x": 9, "2x": 17, "4x": 36}
        if self.memtype != "SDR":
            refresh_mode = "1x" if self.refresh_mode is None else self.refresh_mode
            ref_done = Signal()
            with m.If(ref_issued != 0):
                m.d.sync += ref_done.eq(1)

            # Violation when more than ref_limit*tREFI has elapsed since the
            # last REF ("[%016dps] tREFI violation (too many postponed refreshes)").
            # ref_ps is a past timestamp, so the elapsed time is ps - ref_ps.
            with m.If((ref_issued == 0) & ref_done &
                      (last_ps > (ref_ps + int(ref_limit[refresh_mode] * self.timings["tREFI"])))):
                m.d.sync += ref_done.eq(0)

        return m
424
class FakePHY(Elaboratable):
    """Simulation model of an SDRAM PHY, seen from the DFI side.

    Decodes DFI commands per phase and routes them to one BankModel per bank.
    Write commands are delayed by ``settings.write_latency`` cycles and read
    responses by ``settings.read_latency`` cycles. A DFITimingsChecker is
    instantiated when ``verbosity > SDRAM_VERBOSE_OFF``.

    ``init`` is an optional list of 32-bit words used to preload the banks,
    laid out according to ``address_mapping`` ("ROW_BANK_COL" or
    "BANK_ROW_COL").
    """

    def __prepare_bank_init_data(self, init, nbanks, nrows, ncols, data_width, address_mapping):
        """Split flat 32-bit init words into per-bank lists of model words.

        The words are repacked to ``data_width`` bits (the BankModel word
        size) and distributed across banks following ``address_mapping``.
        The caller's list is not modified.

        Returns a list of ``nbanks`` lists.
        Raises ValueError for an unsupported ``address_mapping``.
        """
        mem_size = (self.settings.databits//8)*(nrows*ncols*nbanks)
        bank_size = mem_size // nbanks
        model_bank_size = bank_size // (data_width//8)
        model_column_size = model_bank_size // nrows
        model_data_ratio = data_width // 32
        data_width_bytes = data_width // 8
        bank_init = [[] for _ in range(nbanks)]

        # Work on a copy: the padding below must not leak into the caller's list.
        init = list(init)

        # Pad init if too short
        if len(init) % data_width_bytes != 0:
            init.extend([0]*(data_width_bytes - len(init) % data_width_bytes))

        # Convert init data width from 32-bit to data_width if needed
        if model_data_ratio > 1:
            # Merge groups of 32-bit words into one wider word (first word in the LSBs).
            new_init = [0]*(len(init)//model_data_ratio)
            for i in range(0, len(init), model_data_ratio):
                ints = init[i:i+model_data_ratio]
                strs = "".join("{:08x}".format(x) for x in reversed(ints))
                new_init[i//model_data_ratio] = int(strs, 16)
            init = new_init
        elif model_data_ratio == 0:
            # Split each 32-bit word into 8- or 16-bit words (native byte order).
            assert data_width_bytes in [1, 2]
            model_data_ratio = 4 // data_width_bytes
            struct_unpack_patterns = {1: "4B", 2: "2H"}
            new_init = [0]*int(len(init)*model_data_ratio)
            for i in range(len(init)):
                new_init[model_data_ratio*i:model_data_ratio*(i+1)] = struct.unpack(
                    struct_unpack_patterns[data_width_bytes],
                    struct.pack("I", init[i])
                )[0:model_data_ratio]
            init = new_init

        if address_mapping == "ROW_BANK_COL":
            # Row by row; within a row, each bank takes the next model_column_size words.
            for row in range(nrows):
                for bank in range(nbanks):
                    start = row*nbanks*model_column_size + bank*model_column_size
                    if start > len(init):
                        # Later chunks start even further out: nothing left to place.
                        return bank_init
                    end = min(start + model_column_size, len(init))
                    bank_init[bank].extend(init[start:end])
        elif address_mapping == "BANK_ROW_COL":
            # Each bank owns one contiguous slice of model_bank_size words.
            for bank in range(nbanks):
                start = bank*model_bank_size
                if start > len(init):
                    break
                end = min(start + model_bank_size, len(init))
                bank_init[bank] = init[start:end]
        else:
            raise ValueError("Unsupported address_mapping: {!r}".format(address_mapping))

        return bank_init

    def __init__(self, module, settings, clk_freq=100e6,
                 we_granularity=8,
                 init=None,
                 address_mapping="ROW_BANK_COL",
                 verbosity=SDRAM_VERBOSE_OFF):

        # Parameters -------------------------------------------------------------------------------
        self.burst_length = {
            "SDR":   1,
            "DDR":   2,
            "LPDDR": 2,
            "DDR2":  2,
            "DDR3":  2,
            "DDR4":  2,
        }[settings.memtype]

        self.addressbits = module.geom_settings.addressbits
        self.bankbits = module.geom_settings.bankbits
        self.rowbits = module.geom_settings.rowbits
        self.colbits = module.geom_settings.colbits

        self.settings = settings
        self.module = module

        self.verbosity = verbosity
        self.clk_freq = clk_freq
        self.we_granularity = we_granularity

        # None (the default) means no init data; avoids a shared mutable default.
        self.init = [] if init is None else init
        # Needed by elaborate() to lay out the init data.
        self.address_mapping = address_mapping

        # DFI Interface ----------------------------------------------------------------------------
        self.dfi = Interface(
            addressbits = self.addressbits,
            bankbits    = self.bankbits,
            nranks      = self.settings.nranks,
            databits    = self.settings.dfi_databits,
            nphases     = self.settings.nphases
        )

    def elaborate(self, platform):
        m = Module()

        nphases = self.settings.nphases
        nbanks = 2**self.bankbits
        nrows = 2**self.rowbits
        ncols = 2**self.colbits
        data_width = self.settings.dfi_databits*nphases

        # DFI phases -------------------------------------------------------------------------------
        phases = [DFIPhaseModel(self.dfi, n) for n in range(nphases)]
        m.submodules += phases

        # DFI timing checker -----------------------------------------------------------------------
        if self.verbosity > SDRAM_VERBOSE_OFF:
            # Per-phase clock period, in ns.
            timings = {"tCK": (1e9 / self.clk_freq) / nphases}

            for name in _speedgrade_timings + _technology_timings:
                timings[name] = self.module.get(name)

            timing_checker = DFITimingsChecker(
                dfi          = self.dfi,
                nbanks       = nbanks,
                nphases      = nphases,
                timings      = timings,
                refresh_mode = self.module.timing_settings.fine_refresh_mode,
                memtype      = self.settings.memtype,
                verbose      = self.verbosity > SDRAM_VERBOSE_DBG)
            m.submodules += timing_checker

        # Bank init data ---------------------------------------------------------------------------
        bank_init = [None]*nbanks

        if self.init:
            bank_init = self.__prepare_bank_init_data(
                init            = self.init,
                nbanks          = nbanks,
                nrows           = nrows,
                ncols           = ncols,
                data_width      = data_width,
                address_mapping = self.address_mapping,
            )

        # Banks ------------------------------------------------------------------------------------
        banks = [BankModel(
            data_width     = data_width,
            nrows          = nrows,
            ncols          = ncols,
            burst_length   = self.burst_length,
            nphases        = nphases,
            we_granularity = self.we_granularity,
            init           = bank_init[i]) for i in range(nbanks)]
        m.submodules += banks

        # Connect DFI phases to Banks (CMDs, Write datapath) ---------------------------------------
        # Each command type is collected into a per-phase vector and decoded
        # with a Switch on one-hot values: a command reaches the banks only
        # when exactly one phase issues that command type in the cycle.
        for nb, bank in enumerate(banks):
            # Bank activate
            activates = Signal(len(phases))
            m.d.comb += activates.eq(Cat(*[phase.activate for phase in phases]))
            with m.Switch(activates):
                for np, phase in enumerate(phases):
                    with m.Case(2**np):
                        m.d.comb += [
                            bank.activate.eq(phase.bank == nb),
                            bank.activate_row.eq(phase.address),
                        ]

            # Bank precharge (A10 set means precharge all banks)
            precharges = Signal(len(phases))
            m.d.comb += precharges.eq(Cat(*[phase.precharge for phase in phases]))
            with m.Switch(precharges):
                for np, phase in enumerate(phases):
                    with m.Case(2**np):
                        m.d.comb += bank.precharge.eq((phase.bank == nb) | phase.address[10])

            # Bank writes
            bank_write = Signal()
            bank_write_col = Signal(range(ncols))
            writes = Signal(len(phases))
            m.d.comb += writes.eq(Cat(*[phase.write for phase in phases]))
            with m.Switch(writes):
                for np, phase in enumerate(phases):
                    with m.Case(2**np):
                        m.d.comb += [
                            bank_write.eq(phase.bank == nb),
                            bank_write_col.eq(phase.address),
                        ]
            m.d.comb += [
                bank.write_data.eq(Cat(*[phase.wrdata for phase in phases])),
                bank.write_mask.eq(Cat(*[phase.wrdata_mask for phase in phases])),
            ]

            # Simulate write latency
            for i in range(self.settings.write_latency):
                new_bank_write = Signal()
                new_bank_write_col = Signal(range(ncols))
                m.d.sync += [
                    new_bank_write.eq(bank_write),
                    new_bank_write_col.eq(bank_write_col),
                ]
                bank_write = new_bank_write
                bank_write_col = new_bank_write_col

            m.d.comb += [
                bank.write.eq(bank_write),
                bank.write_col.eq(bank_write_col),
            ]

            # Bank reads
            reads = Signal(len(phases))
            m.d.comb += reads.eq(Cat(*[phase.read for phase in phases]))
            with m.Switch(reads):
                for np, phase in enumerate(phases):
                    with m.Case(2**np):
                        m.d.comb += [
                            bank.read.eq(phase.bank == nb),
                            bank.read_col.eq(phase.address),
                        ]

        # Connect Banks to DFI phases (CMDs, Read datapath) ----------------------------------------
        banks_read = Signal()
        banks_read_data = Signal(data_width)
        m.d.comb += [
            banks_read.eq(reduce(or_, [bank.read for bank in banks])),
            banks_read_data.eq(reduce(or_, [bank.read_data for bank in banks])),
        ]

        # Simulate read latency --------------------------------------------------------------------
        for i in range(self.settings.read_latency):
            new_banks_read = Signal()
            new_banks_read_data = Signal(data_width)
            m.d.sync += [
                new_banks_read.eq(banks_read),
                new_banks_read_data.eq(banks_read_data),
            ]
            banks_read = new_banks_read
            banks_read_data = new_banks_read_data

        m.d.comb += [
            Cat(*[phase.rddata_valid for phase in phases]).eq(banks_read),
            Cat(*[phase.rddata for phase in phases]).eq(banks_read_data),
        ]

        return m