microudp: fix compilation warning
[litex.git] / milkymist / asmicon / bankmachine.py
index 187a904d41f6c0595b02a990986c85f42987bc71..361631de698cf1ada98ed4e05d24ebee5f19f9e2 100644 (file)
@@ -1,8 +1,8 @@
-from migen.fhdl.structure import *
+from migen.fhdl.std import *
 from migen.bus.asmibus import *
-from migen.corelogic.roundrobin import *
-from migen.corelogic.fsm import FSM
-from migen.corelogic.misc import optree
+from migen.genlib.roundrobin import *
+from migen.genlib.fsm import FSM
+from migen.genlib.misc import optree
 
 from milkymist.asmicon.multiplexer import *
 
@@ -31,99 +31,83 @@ class _AddressSlicer:
                if isinstance(address, int):
                        return (address & (2**self._b1 - 1)) << self.address_align
                else:
-                       return Cat(Constant(0, BV(self.address_align)), address[:self._b1])
+                       return Cat(Replicate(0, self.address_align), address[:self._b1])
 
-class _Selector:
+class _Selector(Module):
        def __init__(self, slicer, bankn, slots):
-               self.slicer = slicer
-               self.bankn = bankn
-               self.slots = slots
-               
-               self.nslots = len(self.slots)
+               nslots = len(slots)
                self.stb = Signal()
                self.ack = Signal()
-               self.tag = Signal(BV(bits_for(self.nslots-1)))
-               self.adr = Signal(self.slots[0].adr.bv)
+               self.tag = Signal(max=nslots)
+               self.adr = Signal(slots[0].adr.nbits)
                self.we = Signal()
                
                # derived classes should drive rr.request
-               self.rr = RoundRobin(self.nslots, SP_CE)
+               self.submodules.rr = RoundRobin(nslots, SP_CE)
        
-       def get_fragment(self):
-               comb = []
-               rr = self.rr
-               
+               ###
+
                # Multiplex
-               state = Signal(BV(2))
-               comb += [
-                       state.eq(Array(slot.state for slot in self.slots)[rr.grant]),
-                       self.adr.eq(Array(slot.adr for slot in self.slots)[rr.grant]),
-                       self.we.eq(Array(slot.we for slot in self.slots)[rr.grant]),
+               rr = self.rr
+               state = Signal(2)
+               self.comb += [
+                       state.eq(Array(slot.state for slot in slots)[rr.grant]),
+                       self.adr.eq(Array(slot.adr for slot in slots)[rr.grant]),
+                       self.we.eq(Array(slot.we for slot in slots)[rr.grant]),
                        self.stb.eq(
-                               (self.slicer.bank(self.adr) == self.bankn) \
+                               (slicer.bank(self.adr) == bankn) \
                                & (state == SLOT_PENDING)),
-                       rr.ce.eq(self.ack),
+                       rr.ce.eq(self.ack | ~self.stb),
                        self.tag.eq(rr.grant)
                ]
-               comb += [If((rr.grant == i) & self.stb & self.ack, slot.process.eq(1))
-                       for i, slot in enumerate(self.slots)]
-                       
-               return Fragment(comb) + rr.get_fragment()
+               self.comb += [If((rr.grant == i) & self.stb & self.ack, slot.process.eq(1))
+                       for i, slot in enumerate(slots)]
+
+               self.complete_selector(slicer, bankn, slots)
 
 class _SimpleSelector(_Selector):
-       def get_fragment(self):
-               comb = []
-               for i, slot in enumerate(self.slots):
-                       comb.append(self.rr.request[i].eq(
-                               (self.slicer.bank(slot.adr) == self.bankn) & \
-                               (slot.state == SLOT_PENDING)
-                       ))
-       
-               return Fragment(comb) + super().get_fragment()
+       def complete_selector(self, slicer, bankn, slots):
+               for i, slot in enumerate(slots):
+                       self.comb += self.rr.request[i].eq(
+                               (slicer.bank(slot.adr) == bankn) & \
+                               (slot.state == SLOT_PENDING))
 
 class _FullSelector(_Selector):
-       def get_fragment(self):
-               comb = []
-               sync = []
+       def complete_selector(self, slicer, bankn, slots):
                rr = self.rr
 
                # List outstanding requests for our bank
                outstandings = []
-               for slot in self.slots:
+               for slot in slots:
                        outstanding = Signal()
-                       comb.append(outstanding.eq(
-                               (self.slicer.bank(slot.adr) == self.bankn) & \
-                               (slot.state == SLOT_PENDING)
-                       ))
+                       self.comb += outstanding.eq(
+                               (slicer.bank(slot.adr) == bankn) & \
+                               (slot.state == SLOT_PENDING))
                        outstandings.append(outstanding)
                
                # Row tracking
-               openrow_r = Signal(BV(self.slicer.geom_settings.row_a))
-               openrow_n = Signal(BV(self.slicer.geom_settings.row_a))
-               openrow = Signal(BV(self.slicer.geom_settings.row_a))
-               comb += [
-                       openrow_n.eq(self.slicer.row(self.adr)),
+               openrow_r = Signal(slicer.geom_settings.row_a)
+               openrow_n = Signal(slicer.geom_settings.row_a)
+               openrow = Signal(slicer.geom_settings.row_a)
+               self.comb += [
+                       openrow_n.eq(slicer.row(self.adr)),
                        If(self.stb,
                                openrow.eq(openrow_n)
                        ).Else(
                                openrow.eq(openrow_r)
                        )
                ]
-               sync += [
-                       If(self.stb & self.ack,
-                               openrow_r.eq(openrow_n)
-                       )
-               ]
+               self.sync += If(self.stb & self.ack, openrow_r.eq(openrow_n))
                hits = []
-               for slot, os in zip(self.slots, outstandings):
+               for slot, os in zip(slots, outstandings):
                        hit = Signal()
-                       comb.append(hit.eq((self.slicer.row(slot.adr) == openrow) & os))
+                       self.comb += hit.eq((slicer.row(slot.adr) == openrow) & os)
                        hits.append(hit)
                
                # Determine best request
                rr = RoundRobin(self.nslots, SP_CE)
                has_hit = Signal()
-               comb.append(has_hit.eq(optree("|", hits)))
+               self.comb += has_hit.eq(optree("|", hits))
                
                best_hit = [rr.request[i].eq(hit)
                        for i, hit in enumerate(hits)]
@@ -135,84 +119,72 @@ class _FullSelector(_Selector):
                                *best_fallback
                        )
                
-               if self.slots[0].time:
+               if slots[0].time:
                        # Implement anti-starvation timer
                        matures = []
-                       for slot, os in zip(self.slots, outstandings):
+                       for slot, os in zip(slots, outstandings):
                                mature = Signal()
                                comb.append(mature.eq(slot.mature & os))
                                matures.append(mature)
                        has_mature = Signal()
-                       comb.append(has_mature.eq(optree("|", matures)))
+                       self.comb += has_mature.eq(optree("|", matures))
                        best_mature = [rr.request[i].eq(mature)
                                for i, mature in enumerate(matures)]
                        select_stmt = If(has_mature, *best_mature).Else(select_stmt)
-               comb.append(select_stmt)
-               
-               return Fragment(comb, sync) + super().get_fragment()
+               self.comb += select_stmt
 
-class _Buffer:
+class _Buffer(Module):
        def __init__(self, source):
-               self.source = source
-               
                self.stb = Signal()
                self.ack = Signal()
-               self.tag = Signal(self.source.tag.bv)
-               self.adr = Signal(self.source.adr.bv)
+               self.tag = Signal(source.tag.bv)
+               self.adr = Signal(source.adr.bv)
                self.we = Signal()
        
-       def get_fragment(self):
+               ###
+
                en = Signal()
-               comb = [
+               self.comb += [
                        en.eq(self.ack | ~self.stb),
-                       self.source.ack.eq(en)
+                       source.ack.eq(en)
                ]
-               sync = [
+               self.sync += [
                        If(en,
-                               self.stb.eq(self.source.stb),
-                               self.tag.eq(self.source.tag),
-                               self.adr.eq(self.source.adr),
-                               self.we.eq(self.source.we)
+                               self.stb.eq(source.stb),
+                               self.tag.eq(source.tag),
+                               self.adr.eq(source.adr),
+                               self.we.eq(source.we)
                        )
                ]
-               return Fragment(comb, sync)
        
-class BankMachine:
+class BankMachine(Module):
        def __init__(self, geom_settings, timing_settings, address_align, bankn, slots, full_selector):
-               self.geom_settings = geom_settings
-               self.timing_settings = timing_settings
-               self.address_align = address_align
-               self.bankn = bankn
-               self.slots = slots
-               self.full_selector = full_selector
-               
                self.refresh_req = Signal()
                self.refresh_gnt = Signal()
                self.cmd = CommandRequestRW(geom_settings.mux_a, geom_settings.bank_a,
                        bits_for(len(slots)-1))
 
-       def get_fragment(self):
-               comb = []
-               sync = []
-               
+               ###
+
                # Sub components
-               slicer = _AddressSlicer(self.geom_settings, self.address_align)
-               if self.full_selector:
-                       selector = _FullSelector(slicer, self.bankn, self.slots)
-                       buf = _Buffer(selector)
-                       cmdsource = buf
+               slicer = _AddressSlicer(geom_settings, address_align)
+               if full_selector:
+                       selector = _FullSelector(slicer, bankn, slots)
+                       self.submodules.buf = _Buffer(selector)
+                       cmdsource = self.buf
                else:
-                       selector = _SimpleSelector(slicer, self.bankn, self.slots)
+                       selector = _SimpleSelector(slicer, bankn, slots)
                        cmdsource = selector
+               self.submodules += selector
                
                # Row tracking
                has_openrow = Signal()
-               openrow = Signal(BV(self.geom_settings.row_a))
+               openrow = Signal(geom_settings.row_a)
                hit = Signal()
-               comb.append(hit.eq(openrow == slicer.row(cmdsource.adr)))
+               self.comb += hit.eq(openrow == slicer.row(cmdsource.adr))
                track_open = Signal()
                track_close = Signal()
-               sync += [
+               self.sync += [
                        If(track_open,
                                has_openrow.eq(1),
                                openrow.eq(slicer.row(cmdsource.adr))
@@ -224,8 +196,8 @@ class BankMachine:
                
                # Address generation
                s_row_adr = Signal()
-               comb += [
-                       self.cmd.ba.eq(self.bankn),
+               self.comb += [
+                       self.cmd.ba.eq(bankn),
                        If(s_row_adr,
                                self.cmd.a.eq(slicer.row(cmdsource.adr))
                        ).Else(
@@ -233,19 +205,34 @@ class BankMachine:
                        )
                ]
                
-               comb.append(self.cmd.tag.eq(cmdsource.tag))
+               self.comb += self.cmd.tag.eq(cmdsource.tag)
+               
+               # Respect write-to-precharge specification
+               precharge_ok = Signal()
+               t_unsafe_precharge = 2 + timing_settings.tWR - 1
+               unsafe_precharge_count = Signal(max=t_unsafe_precharge+1)
+               self.comb += precharge_ok.eq(unsafe_precharge_count == 0)
+               self.sync += [
+                       If(self.cmd.stb & self.cmd.ack & self.cmd.is_write,
+                               unsafe_precharge_count.eq(t_unsafe_precharge)
+                       ).Elif(~precharge_ok,
+                               unsafe_precharge_count.eq(unsafe_precharge_count-1)
+                       )
+               ]
                
                # Control and command generation FSM
                fsm = FSM("REGULAR", "PRECHARGE", "ACTIVATE", "REFRESH", delayed_enters=[
-                       ("TRP", "ACTIVATE", self.timing_settings.tRP-1),
-                       ("TRCD", "REGULAR", self.timing_settings.tRCD-1)
+                       ("TRP", "ACTIVATE", timing_settings.tRP-1),
+                       ("TRCD", "REGULAR", timing_settings.tRCD-1)
                ])
+               self.submodules += fsm
                fsm.act(fsm.REGULAR,
                        If(self.refresh_req,
                                fsm.next_state(fsm.REFRESH)
                        ).Elif(cmdsource.stb,
                                If(has_openrow,
                                        If(hit,
+                                               # NB: write-to-read specification is enforced by multiplexer
                                                self.cmd.stb.eq(1),
                                                cmdsource.ack.eq(self.cmd.ack),
                                                self.cmd.is_read.eq(~cmdsource.we),
@@ -265,10 +252,12 @@ class BankMachine:
                        # 1. we are presenting the column address, A10 is always low
                        # 2. since we always go to the ACTIVATE state, we do not need
                        # to assert track_close.
-                       self.cmd.stb.eq(1),
-                       If(self.cmd.ack, fsm.next_state(fsm.TRP)),
-                       self.cmd.ras_n.eq(0),
-                       self.cmd.we_n.eq(0)
+                       If(precharge_ok,
+                               self.cmd.stb.eq(1),
+                               If(self.cmd.ack, fsm.next_state(fsm.TRP)),
+                               self.cmd.ras_n.eq(0),
+                               self.cmd.we_n.eq(0)
+                       )
                )
                fsm.act(fsm.ACTIVATE,
                        s_row_adr.eq(1),
@@ -278,16 +267,7 @@ class BankMachine:
                        self.cmd.ras_n.eq(0)
                )
                fsm.act(fsm.REFRESH,
-                       self.refresh_gnt.eq(1),
+                       self.refresh_gnt.eq(precharge_ok),
                        track_close.eq(1),
                        If(~self.refresh_req, fsm.next_state(fsm.REGULAR))
                )
-               
-               if self.full_selector:
-                       buf_fragment = buf.get_fragment()
-               else:
-                       buf_fragment = Fragment()
-               return Fragment(comb, sync) + \
-                       selector.get_fragment() + \
-                       buf_fragment + \
-                       fsm.get_fragment()