Remove wants_zqcs signal
[gram.git] / gram / core / refresher.py
index 14f152afaca65ab7e51356df2e98f7cedb30f055..9ab43592a8bc924a7a1281fc031d2e0f5d1f5169 100644 (file)
@@ -6,7 +6,8 @@
 """LiteDRAM Refresher."""
 
 from nmigen import *
-from nmigen.utils import bits_for, log2_int
+from nmigen.utils import log2_int
+from nmigen.asserts import Assert, Assume
 
 from gram.core.multiplexer import *
 from gram.compat import Timeline
@@ -25,14 +26,14 @@ class RefreshExecuter(Elaboratable):
     - Wait tRFC
     """
 
-    def __init__(self, trp, trfc):
+    def __init__(self, abits, babits, trp, trfc):
         self.start = Signal()
         self.done = Signal()
         self._trp = trp
         self._trfc = trfc
 
-        self.a = Signal()
-        self.ba = Signal()
+        self.a = Signal(abits)
+        self.ba = Signal(babits)
         self.cas = Signal()
         self.ras = Signal()
         self.we = Signal()
@@ -43,7 +44,7 @@ class RefreshExecuter(Elaboratable):
         trp = self._trp
         trfc = self._trfc
 
-        tl = Timeline([
+        m.submodules.timeline = tl = Timeline([
             # Precharge All
             (0, [
                 self.a.eq(2**10),
@@ -70,7 +71,6 @@ class RefreshExecuter(Elaboratable):
                 self.done.eq(1),
             ]),
         ])
-        m.submodules += tl
         m.d.comb += tl.trigger.eq(self.start)
 
         return m
@@ -84,16 +84,18 @@ class RefreshSequencer(Elaboratable):
     Sequence N refreshs to the DRAM.
     """
 
-    def __init__(self, trp, trfc, postponing=1):
+    def __init__(self, abits, babits, trp, trfc, postponing=1):
         self.start = Signal()
         self.done = Signal()
 
         self._trp = trp
         self._trfc = trfc
         self._postponing = postponing
+        self._abits = abits
+        self._babits = babits
 
-        self.a = Signal()
-        self.ba = Signal()
+        self.a = Signal(abits)
+        self.ba = Signal(babits)
         self.cas = Signal()
         self.ras = Signal()
         self.we = Signal()
@@ -101,8 +103,7 @@ class RefreshSequencer(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
-        executer = RefreshExecuter(self._trp, self._trfc)
-        m.submodules += executer
+        m.submodules.executer = executer = RefreshExecuter(self._abits, self._babits, self._trp, self._trfc)
         m.d.comb += [
             self.a.eq(executer.a),
             self.ba.eq(executer.ba),
@@ -111,18 +112,42 @@ class RefreshSequencer(Elaboratable):
             self.we.eq(executer.we),
         ]
 
-        count = Signal(bits_for(self._postponing), reset=self._postponing-1)
+        countEqZero = Signal(reset=(self._postponing <= 1))
+        countDiffZero = Signal(reset=(self._postponing > 1))
+
+        count = Signal(range(self._postponing), reset=self._postponing-1)
         with m.If(self.start):
-            m.d.sync += count.eq(count.reset)
+            m.d.sync += [
+                count.eq(count.reset),
+                countEqZero.eq(self._postponing <= 1),
+                countDiffZero.eq(self._postponing > 1),
+            ]
         with m.Elif(executer.done):
             with m.If(count != 0):
                 m.d.sync += count.eq(count-1)
 
+            with m.If(count == 1):
+                m.d.sync += [
+                    countEqZero.eq(1),
+                    countDiffZero.eq(0),
+                ]
+            with m.Else():
+                m.d.sync += [
+                    countEqZero.eq(0),
+                    countDiffZero.eq(1),
+                ]
+
         m.d.comb += [
-            executer.start.eq(self.start | (count != 0)),
-            self.done.eq(executer.done & (count == 0)),
+            executer.start.eq(self.start | countDiffZero),
+            self.done.eq(executer.done & countEqZero),
         ]
 
+        if platform == "formal":
+            m.d.comb += [
+                Assert(countEqZero == (count == 0)),
+                Assert(countDiffZero == (count != 0)),
+            ]
+
         return m
 
 # RefreshTimer -------------------------------------------------------------------------------------
@@ -135,9 +160,12 @@ class RefreshTimer(Elaboratable):
     """
 
     def __init__(self, trefi):
+        # TODO: we don't pass formal verification for trefi = 1
+        assert trefi != 1
+
         self.wait = Signal()
         self.done = Signal()
-        self.count = Signal(bits_for(trefi))
+        self.count = Signal(range(trefi), reset=trefi-1)
         self._trefi = trefi
 
     def elaborate(self, platform):
@@ -145,19 +173,19 @@ class RefreshTimer(Elaboratable):
 
         trefi = self._trefi
 
-        done = Signal()
-        count = Signal(bits_for(trefi), reset=trefi-1)
+        with m.If(self.wait & (self.count != 0)):
+            m.d.sync += self.count.eq(self.count-1)
 
-        with m.If(self.wait & ~self.done):
-            m.d.sync += count.eq(count-1)
+            with m.If(self.count == 1):
+                m.d.sync += self.done.eq(1)
         with m.Else():
-            m.d.sync += count.eq(count.reset)
+            m.d.sync += [
+                self.count.eq(self.count.reset),
+                self.done.eq(0),
+            ]
 
-        m.d.comb += [
-            done.eq(count == 0),
-            self.done.eq(done),
-            self.count.eq(count)
-        ]
+        if platform == "formal":
+            m.d.comb += Assert(self.done == (self.count == 0))
 
         return m
 
@@ -178,9 +206,8 @@ class RefreshPostponer(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
-        count = Signal(bits_for(self._postponing), reset=self._postponing-1)
+        count = Signal(range(self._postponing), reset=self._postponing-1)
 
-        m.d.sync += self.req_o.eq(0)
         with m.If(self.req_i):
             with m.If(count == 0):
                 m.d.sync += [
@@ -188,7 +215,12 @@ class RefreshPostponer(Elaboratable):
                     self.req_o.eq(1),
                 ]
             with m.Else():
-                m.d.sync += count.eq(count-1)
+                m.d.sync += [
+                    count.eq(count-1),
+                    self.req_o.eq(0),
+                ]
+        with m.Else():
+            m.d.sync += self.req_o.eq(0)
 
         return m
 
@@ -205,14 +237,14 @@ class ZQCSExecuter(Elaboratable):
     - Wait tZQCS
     """
 
-    def __init__(self, trp, tzqcs):
+    def __init__(self, abits, babits, trp, tzqcs):
         self.start = Signal()
         self.done = Signal()
         self._trp = trp
         self._tzqcs = tzqcs
 
-        self.a = Signal()
-        self.ba = Signal()
+        self.a = Signal(abits)
+        self.ba = Signal(babits)
         self.cas = Signal()
         self.ras = Signal()
         self.we = Signal()
@@ -223,7 +255,7 @@ class ZQCSExecuter(Elaboratable):
         trp = self._trp
         tzqcs = self._tzqcs
 
-        tl = Timeline([
+        m.submodules.timeline = tl = Timeline([
             # Precharge All
             (0, [
                 self.a.eq(2**10),
@@ -252,7 +284,6 @@ class ZQCSExecuter(Elaboratable):
                 self.done.eq(1)
             ]),
         ])
-        m.submodules += tl
         m.d.comb += tl.trigger.eq(self.start)
 
         return m
@@ -277,10 +308,10 @@ class Refresher(Elaboratable):
 
     def __init__(self, settings, clk_freq, zqcs_freq=1e0, postponing=1):
         assert postponing <= 8
-        abits = settings.geom.addressbits
-        babits = settings.geom.bankbits + log2_int(settings.phy.nranks)
+        self._abits = settings.geom.addressbits
+        self._babits = settings.geom.bankbits + log2_int(settings.phy.nranks)
         self.cmd = cmd = stream.Endpoint(
-            cmd_request_rw_layout(a=abits, ba=babits))
+            cmd_request_rw_layout(a=self._abits, ba=self._babits))
         self._postponing = postponing
         self._settings = settings
         self._clk_freq = clk_freq
@@ -290,7 +321,6 @@ class Refresher(Elaboratable):
         m = Module()
 
         wants_refresh = Signal()
-        wants_zqcs = Signal()
 
         settings = self._settings
 
@@ -308,19 +338,17 @@ class Refresher(Elaboratable):
         ]
 
         # Refresh Sequencer ------------------------------------------------------------------------
-        sequencer = RefreshSequencer(
-            settings.timing.tRP, settings.timing.tRFC, self._postponing)
+        sequencer = RefreshSequencer(self._abits, self._babits, settings.timing.tRP, settings.timing.tRFC, self._postponing)
         m.submodules.sequencer = sequencer
 
         if settings.timing.tZQCS is not None:
+
             # ZQCS Timer ---------------------------------------------------------------------------
             zqcs_timer = RefreshTimer(int(self._clk_freq/self._zqcs_freq))
             m.submodules.zqcs_timer = zqcs_timer
-            m.d.comb += wants_zqcs.eq(zqcs_timer.done)
 
             # ZQCS Executer ------------------------------------------------------------------------
-            zqcs_executer = ZQCSExecuter(
-                settings.timing.tRP, settings.timing.tZQCS)
+            zqcs_executer = ZQCSExecuter(self._abits, self._babits, settings.timing.tRP, settings.timing.tZQCS)
             m.submodules.zqs_executer = zqcs_executer
             m.d.comb += zqcs_timer.wait.eq(~zqcs_executer.done)
 
@@ -338,48 +366,44 @@ class Refresher(Elaboratable):
 
             if settings.timing.tZQCS is None:
                 with m.State("Do-Refresh"):
-                    m.d.comb += [
-                        self.cmd.valid.eq(1),
-                        self.cmd.a.eq(sequencer.a),
-                        self.cmd.ba.eq(sequencer.ba),
-                        self.cmd.cas.eq(sequencer.cas),
-                        self.cmd.ras.eq(sequencer.ras),
-                        self.cmd.we.eq(sequencer.we),
-                    ]
+                    m.d.comb += self.cmd.valid.eq(~sequencer.done)
                     with m.If(sequencer.done):
-                        m.d.comb += [
-                            self.cmd.valid.eq(0),
-                            self.cmd.last.eq(1),
-                        ]
+                        m.d.comb += self.cmd.last.eq(1)
                         m.next = "Idle"
             else:
                 with m.State("Do-Refresh"):
-                    m.d.comb += self.cmd.valid.eq(1)
+                    m.d.comb += self.cmd.valid.eq(zqcs_timer.done & ~sequencer.done)
                     with m.If(sequencer.done):
-                        with m.If(wants_zqcs):
+                        with m.If(zqcs_timer.done):
                             m.d.comb += zqcs_executer.start.eq(1)
                             m.next = "Do-Zqcs"
                         with m.Else():
-                            m.d.comb += [
-                                self.cmd.valid.eq(0),
-                                self.cmd.last.eq(1),
-                            ]
+                            m.d.comb += self.cmd.last.eq(1)
                             m.next = "Idle"
 
                 with m.State("Do-Zqcs"):
-                    m.d.comb += [
-                        self.cmd.valid.eq(1),
-                        self.cmd.a.eq(zqcs_executer.a),
-                        self.cmd.ba.eq(zqcs_executer.ba),
-                        self.cmd.cas.eq(zqcs_executer.cas),
-                        self.cmd.ras.eq(zqcs_executer.ras),
-                        self.cmd.we.eq(zqcs_executer.we),
-                    ]
+                    m.d.comb += self.cmd.valid.eq(~zqcs_executer.done)
                     with m.If(zqcs_executer.done):
-                        m.d.comb += [
-                            self.cmd.valid.eq(0),
-                            self.cmd.last.eq(1),
-                        ]
+                        m.d.comb += self.cmd.last.eq(1)
                         m.next = "Idle"
 
+        # Connect sequencer/executer outputs to cmd
+        if settings.timing.tZQCS is None:
+            m.d.comb += [
+                self.cmd.a.eq(sequencer.a),
+                self.cmd.ba.eq(sequencer.ba),
+                self.cmd.cas.eq(sequencer.cas),
+                self.cmd.ras.eq(sequencer.ras),
+                self.cmd.we.eq(sequencer.we),
+            ]
+        else:
+            m.d.comb += [
+                self.cmd.a.eq(zqcs_executer.a),
+                self.cmd.ba.eq(zqcs_executer.ba),
+                self.cmd.cas.eq(zqcs_executer.cas),
+                self.cmd.ras.eq(zqcs_executer.ras),
+                self.cmd.we.eq(zqcs_executer.we),
+            ]
+
+
         return m