radix: reading first page table entry
[soc.git] / src / soc / experiment / l0_cache.py
index 12b0890b36fe47af797ccc9aa28908f5cccdd7d9..e2c31096a78269e5d5f55548445524db81f79a71 100644 (file)
@@ -24,11 +24,10 @@ from nmigen.hdl.rec import Record, Layout
 
 from nmutil.latch import SRLatch, latchregister
 from soc.decoder.power_decoder2 import Data
-from soc.decoder.power_enums import InternalOp
+from soc.decoder.power_enums import MicrOp
 from soc.regfile.regfile import ortreereduce
 from nmutil.util import treereduce
 
-from soc.fu.ldst.ldst_input_record import CompLDSTOpSubset
 from soc.decoder.power_decoder2 import Data
 #from nmutil.picker import PriorityPicker
 from nmigen.lib.coding import PriorityEncoder
@@ -36,134 +35,36 @@ from soc.scoreboard.addr_split import LDSTSplitter
 from soc.scoreboard.addr_match import LenExpand
 
 # for testing purposes
-from soc.experiment.testmem import TestMemory
-
+from soc.config.test.test_loadstore import TestMemPspec
+from soc.config.loadstore import ConfigMemoryPortInterface
+from soc.experiment.pimem import PortInterface
+from soc.config.test.test_pi2ls import pi_ld, pi_st, pi_ldst
 import unittest
 
+class L0CacheBuffer2(Elaboratable):
+    """L0CacheBuffer2"""
+    def __init__(self, n_units=8, regwid=64, addrwid=48):
+        self.n_units = n_units
+        self.regwid = regwid
+        self.addrwid = addrwid
+        ul = []
+        for i in range(self.n_units):
+            ul += [PortInterface()]
+        self.dports = Array(ul)
 
-class PortInterface(RecordObject):
-    """PortInterface
-
-    defines the interface - the API - that the LDSTCompUnit connects
-    to.  note that this is NOT a "fire-and-forget" interface.  the
-    LDSTCompUnit *must* be kept appraised that the request is in
-    progress, and only when it has a 100% successful completion rate
-    can the notification be given (busy dropped).
-
-    The interface FSM rules are as follows:
-
-    * if busy_o is asserted, a LD/ST is in progress.  further
-      requests may not be made until busy_o is deasserted.
-
-    * only one of is_ld_i or is_st_i may be asserted.  busy_o
-      will immediately be asserted and remain asserted.
-
-    * addr.ok is to be asserted when the LD/ST address is known.
-      addr.data is to be valid on the same cycle.
-
-      addr.ok and addr.data must REMAIN asserted until busy_o
-      is de-asserted.  this ensures that there is no need
-      for the L0 Cache/Buffer to have an additional address latch
-      (because the LDSTCompUnit already has it)
-
-    * addr_ok_o (or addr_exc_o) must be waited for.  these will
-      be asserted *only* for one cycle and one cycle only.
-
-    * addr_exc_o will be asserted if there is no chance that the
-      memory request may be fulfilled.
-
-      busy_o is deasserted on the same cycle as addr_exc_o is asserted.
-
-    * conversely: addr_ok_o must *ONLY* be asserted if there is a
-      HUNDRED PERCENT guarantee that the memory request will be
-      fulfilled.
-
-    * for a LD, ld.ok will be asserted - for only one clock cycle -
-      at any point in the future that is acceptable to the underlying
-      Memory subsystem.  the recipient MUST latch ld.data on that cycle.
-
-      busy_o is deasserted on the same cycle as ld.ok is asserted.
-
-    * for a ST, st.ok may be asserted only after addr_ok_o had been
-      asserted, alongside valid st.data at the same time.  st.ok
-      must only be asserted for one cycle.
-
-      the underlying Memory is REQUIRED to pick up that data and
-      guarantee its delivery.  no back-acknowledgement is required.
-
-      busy_o is deasserted on the cycle AFTER st.ok is asserted.
-    """
-
-    def __init__(self, name=None, regwid=64, addrwid=48):
-
-        self._regwid = regwid
-        self._addrwid = addrwid
-
-        RecordObject.__init__(self, name=name)
-
-        # distinguish op type (ld/st)
-        self.is_ld_i = Signal(reset_less=True)
-        self.is_st_i = Signal(reset_less=True)
-        self.op = CompLDSTOpSubset()  # hm insn_type ld/st duplicates here
-
-        # common signals
-        self.busy_o = Signal(reset_less=True)     # do not use if busy
-        self.go_die_i = Signal(reset_less=True)   # back to reset
-        self.addr = Data(addrwid, "addr_i")            # addr/addr-ok
-        # addr is valid (TLB, L1 etc.)
-        self.addr_ok_o = Signal(reset_less=True)
-        self.addr_exc_o = Signal(reset_less=True)  # TODO, "type" of exception
-
-        # LD/ST
-        self.ld = Data(regwid, "ld_data_o")  # ok to be set by L0 Cache/Buf
-        self.st = Data(regwid, "st_data_i")  # ok to be set by CompUnit
-
-# TODO: elaborate function
-
-
-class DualPortSplitter(Elaboratable):
-    """DualPortSplitter
-
-    * one incoming PortInterface
-    * two *OUTGOING* PortInterfaces
-    * uses LDSTSplitter to do it
-
-    (actually, thinking about it LDSTSplitter could simply be
-     modified to conform to PortInterface: one in, two out)
-
-    once that is done each pair of ports may be wired directly
-    to the dual ports of L0CacheBuffer
+    def elaborate(self, platform):
+        m = Module()
+        comb, sync = m.d.comb, m.d.sync
 
-    The split is carried out so that, regardless of alignment or
-    mis-alignment, outgoing PortInterface[0] takes bit 4 == 0
-    of the address, whilst outgoing PortInterface[1] takes
-    bit 4 == 1.
+        # connect the ports as modules
 
-    PortInterface *may* need to be changed so that the length is
-    a binary number (accepting values 1-16).
-    """
-    def __init__(self):
-        self.outp = [PortInterface(name="outp_0"),
-                     PortInterface(name="outp_1")]
-        self.inp  = PortInterface(name="inp")
-        print(self.outp)
+        for i in range(self.n_units):
+            d = LDSTSplitter(64, 48, 4, self.dports[i])
+            setattr(m.submodules, "ldst_splitter%d" % i, d)
 
-    def elaborate(self, platform):
-        m = Module()
-        comb = m.d.comb
-        m.submodules.splitter = splitter = LDSTSplitter(64, 48, 4)
-        comb += splitter.addr_i.eq(self.inp.addr) #XXX
-        #comb += splitter.len_i.eq()
-        #comb += splitter.valid_i.eq()
-        comb += splitter.is_ld_i.eq(self.inp.is_ld_i)
-        comb += splitter.is_st_i.eq(self.inp.is_st_i)
-        #comb += splitter.st_data_i.eq()
-        #comb += splitter.sld_valid_i.eq()
-        #comb += splitter.sld_data_i.eq()
-        #comb += splitter.sst_valid_i.eq()
+        # state-machine latches TODO
         return m
 
-
 class DataMergerRecord(Record):
     """
     {data: 128 bit, byte_enable: 16 bit}
@@ -177,6 +78,26 @@ class DataMergerRecord(Record):
         self.data.reset_less = True
         self.en.reset_less = True
 
+class CacheRecord(Record):
+    def __init__(self, name=None):
+        layout = (('addr', 37),
+                  ('a_even', 7),
+                  ('bytemask_even', 16),
+                  ('data_even', 128),
+                  ('a_odd', 7),
+                  ('bytemask_odd', 16),
+                  ('data_odd', 128))
+        Record.__init__(self, Layout(layout), name=name)
+
+        self.addr.reset_less = True
+        self.a_even.reset_less = True
+        self.bytemask_even.reset_less = True
+        self.data_even.reset_less = True
+        self.a_odd.reset_less = True
+        self.bytemask_odd.reset_less = True
+        self.data_odd.reset_less = True
+
+
 
 # TODO: formal verification
 class DataMerger(Elaboratable):
@@ -226,13 +147,13 @@ class DataMerger(Elaboratable):
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
-        #(1) pick a row
+        # (1) pick a row
         m.submodules.pick = pick = PriorityEncoder(self.array_size)
         for j in range(self.array_size):
             comb += pick.i[j].eq(self.addr_array_i[j].bool())
         valid = ~pick.n
         idx = pick.o
-        #(2) merge
+        # (2) merge
         with m.If(valid):
             l = []
             for j in range(self.array_size):
@@ -241,80 +162,52 @@ class DataMerger(Elaboratable):
                 with m.If(select):
                     comb += r.eq(self.data_i[j])
                 l.append(r)
-            comb += self.data_o.data.eq(ortreereduce(l,"data"))
-            comb += self.data_o.en.eq(ortreereduce(l,"en"))
+            comb += self.data_o.data.eq(ortreereduce(l, "data"))
+            comb += self.data_o.en.eq(ortreereduce(l, "en"))
 
         return m
 
+class TstDataMerger2(Elaboratable):
+    def __init__(self):
+        self.data_odd = Signal(128,reset_less=True)
+        self.data_even = Signal(128,reset_less=True)
+        self.n_units = 8
+        ul = []
+        for i in range(self.n_units):
+            ul.append(CacheRecord())
+        self.input_array = Array(ul)
 
-class LDSTPort(Elaboratable):
-    def __init__(self, idx, regwid=64, addrwid=48):
-        self.pi = PortInterface("ldst_port%d" % idx, regwid, addrwid)
+    def addr_match(self,j,addr):
+        ret = []
+        for k in range(self.n_units):
+            ret += [(addr[j] == addr[k])]
+        return Cat(*ret)
 
     def elaborate(self, platform):
         m = Module()
-        comb, sync = m.d.comb, m.d.sync
-
-        # latches
-        m.submodules.busy_l = busy_l = SRLatch(False, name="busy")
-        m.submodules.cyc_l = cyc_l = SRLatch(True, name="cyc")
-        comb += cyc_l.s.eq(0)
-        comb += cyc_l.r.eq(0)
-
-        # this is a little weird: we let the L0Cache/Buffer set
-        # the outputs: this module just monitors "state".
-
-        # LD/ST requested activates "busy"
-        with m.If(self.pi.is_ld_i | self.pi.is_st_i):
-            comb += busy_l.s.eq(1)
-
-        # monitor for an exception or the completion of LD.
-        with m.If(self.pi.addr_exc_o):
-            comb += busy_l.r.eq(1)
-
-        # however ST needs one cycle before busy is reset
-        with m.If(self.pi.st.ok | self.pi.ld.ok):
-            comb += cyc_l.s.eq(1)
-
-        with m.If(cyc_l.q):
-            comb += cyc_l.r.eq(1)
-            comb += busy_l.r.eq(1)
-
-        # busy latch outputs to interface
-        comb += self.pi.busy_o.eq(busy_l.q)
-
+        m.submodules.dm_odd = dm_odd = DataMerger(self.n_units)
+        m.submodules.dm_even = dm_even = DataMerger(self.n_units)
+
+        addr_even = []
+        addr_odd = []
+        for j in range(self.n_units):
+            inp = self.input_array[j]
+            addr_even += [Cat(inp.addr,inp.a_even)]
+            addr_odd +=  [Cat(inp.addr,inp.a_odd)]
+
+        for j in range(self.n_units):
+            inp = self.input_array[j]
+            m.d.comb += dm_even.data_i[j].en.eq(inp.bytemask_even)
+            m.d.comb += dm_odd.data_i[j].en.eq(inp.bytemask_odd)
+            m.d.comb += dm_even.data_i[j].data.eq(inp.data_even)
+            m.d.comb += dm_odd.data_i[j].data.eq(inp.data_odd)
+            m.d.comb += dm_even.addr_array_i[j].eq(self.addr_match(j,addr_even))
+            m.d.comb += dm_odd.addr_array_i[j].eq(self.addr_match(j,addr_odd))
+
+        m.d.comb += self.data_odd.eq(dm_odd.data_o.data)
+        m.d.comb += self.data_even.eq(dm_even.data_o.data)
         return m
 
-    def __iter__(self):
-        yield self.pi.is_ld_i
-        yield self.pi.is_st_i
-        yield from self.pi.op.ports()
-        yield self.pi.busy_o
-        yield self.pi.go_die_i
-        yield from self.pi.addr.ports()
-        yield self.pi.addr_ok_o
-        yield self.pi.addr_exc_o
-
-        yield from self.pi.ld.ports()
-        yield from self.pi.st.ports()
-
-    def ports(self):
-        return list(self)
-
-# TODO: turn this into a module
-def byte_reverse(m, data, length):
-    comb = m.d.comb
-    name = "%s_r" % (data.name)
-    data_r = Signal.like(data, name=name)
-    with m.Switch(length):
-        for j in [1,2,4,8]:
-            with m.Case(j):
-                for i in range(j):
-                    dest = data_r.word_select(i, 8)
-                    src = data.word_select(j-1-i, 8)
-                    comb += dest.eq(src)
-    return data_r
-
 
 class L0CacheBuffer(Elaboratable):
     """L0 Cache / Buffer
@@ -330,194 +223,116 @@ class L0CacheBuffer(Elaboratable):
     a "demo" / "test" class, and one important aspect: it responds
     combinatorially, where a nmigen FSM's state-changes only activate
     on clock-sync boundaries.
+
+    Note: the data byte-order is *not* expected to be normalised (LE/BE)
+    by this class.  That task is taken care of by LDSTCompUnit.
     """
 
-    def __init__(self, n_units, mem, regwid=64, addrwid=48):
+    def __init__(self, n_units, pimem, regwid=64, addrwid=48):
         self.n_units = n_units
-        self.mem = mem
+        self.pimem = pimem
         self.regwid = regwid
         self.addrwid = addrwid
         ul = []
         for i in range(n_units):
-            ul.append(LDSTPort(i, regwid, addrwid))
+            ul.append(PortInterface("ldst_port%d" % i, regwid, addrwid))
         self.dports = Array(ul)
 
-    @property
-    def addrbits(self):
-        return log2_int(self.mem.regwid//8)
-
-    def splitaddr(self, addr):
-        """split the address into top and bottom bits of the memory granularity
-        """
-        return addr[:self.addrbits], addr[self.addrbits:]
-
     def elaborate(self, platform):
         m = Module()
         comb, sync = m.d.comb, m.d.sync
 
         # connect the ports as modules
-        for i in range(self.n_units):
-            setattr(m.submodules, "port%d" % i, self.dports[i])
+        for i in range(self.n_units):
+        #    setattr(m.submodules, "port%d" % i, self.dports[i])
 
         # state-machine latches
-        m.submodules.st_active = st_active = SRLatch(False, name="st_active")
-        m.submodules.ld_active = ld_active = SRLatch(False, name="ld_active")
-        m.submodules.reset_l = reset_l = SRLatch(True, name="reset")
         m.submodules.idx_l = idx_l = SRLatch(False, name="idx_l")
-        m.submodules.adrok_l = adrok_l = SRLatch(False, name="addr_acked")
+        m.submodules.reset_l = reset_l = SRLatch(True, name="reset")
 
         # find one LD (or ST) and do it.  only one per cycle.
         # TODO: in the "live" (production) L0Cache/Buffer, merge multiple
         # LD/STs using mask-expansion - see LenExpand class
 
-        m.submodules.ldpick = ldpick = PriorityEncoder(self.n_units)
-        m.submodules.stpick = stpick = PriorityEncoder(self.n_units)
-        m.submodules.lenexp = lenexp = LenExpand(4, 8)
+        m.submodules.pick = pick = PriorityEncoder(self.n_units)
 
-        lds = Signal(self.n_units, reset_less=True)
-        sts = Signal(self.n_units, reset_less=True)
-        ldi = []
-        sti = []
+        ldsti = []
         for i in range(self.n_units):
-            pi = self.dports[i].pi
-            ldi.append(pi.is_ld_i & pi.busy_o)  # accumulate ld-req signals
-            sti.append(pi.is_st_i & pi.busy_o)  # accumulate st-req signals
-        # put the requests into the priority-pickers
-        comb += ldpick.i.eq(Cat(*ldi))
-        comb += stpick.i.eq(Cat(*sti))
+            pi = self.dports[i]
+            busy = (pi.is_ld_i | pi.is_st_i)  # & pi.busy_o
+            ldsti.append(busy)  # accumulate ld/st-req
+        # put the requests into the priority-picker
+        comb += pick.i.eq(Cat(*ldsti))
 
         # hmm, have to select (record) the right port index
         nbits = log2_int(self.n_units, False)
-        ld_idx = Signal(nbits, reset_less=False)
-        st_idx = Signal(nbits, reset_less=False)
+        idx = Signal(nbits, reset_less=False)
+
         # use these because of the sync-and-comb pass-through capability
-        latchregister(m, ldpick.o, ld_idx, idx_l.qn, name="ld_idx_l")
-        latchregister(m, stpick.o, st_idx, idx_l.qn, name="st_idx_l")
+        latchregister(m, pick.o, idx, idx_l.q, name="idx_l")
 
         # convenience variables to reference the "picked" port
-        ldport = self.dports[ld_idx].pi
-        stport = self.dports[st_idx].pi
-        # and the memory ports
-        rdport = self.mem.rdport
-        wrport = self.mem.wrport
-
-        # Priority-Pickers pick one and only one request, capture its index.
-        # from that point on this code *only* "listens" to that port.
-
-        sync += adrok_l.s.eq(0)
-        comb += adrok_l.r.eq(0)
-        with m.If(~ldpick.n):
-            comb += ld_active.s.eq(1)  # activate LD mode
-            comb += idx_l.r.eq(1)  # pick (and capture) the port index
-        with m.Elif(~stpick.n):
-            comb += st_active.s.eq(1)  # activate ST mode
-            comb += idx_l.r.eq(1)  # pick (and capture) the port index
+        port = self.dports[idx]
+
+        # pick (and capture) the port index
+        with m.If(~pick.n):
+            comb += idx_l.s.eq(1)
 
         # from this point onwards, with the port "picked", it stays picked
-        # until ld_active (or st_active) are de-asserted.
-
-        # if now in "LD" mode: wait for addr_ok, then send the address out
-        # to memory, acknowledge address, and send out LD data
-        with m.If(ld_active.q):
-            # set up LenExpander with the LD len and lower bits of addr
-            lsbaddr, msbaddr = self.splitaddr(ldport.addr.data)
-            comb += lenexp.len_i.eq(ldport.op.data_len)
-            comb += lenexp.addr_i.eq(lsbaddr)
-            with m.If(ldport.addr.ok & adrok_l.qn):
-                comb += rdport.addr.eq(msbaddr) # addr ok, send thru
-                comb += ldport.addr_ok_o.eq(1)  # acknowledge addr ok
-                sync += adrok_l.s.eq(1)       # and pull "ack" latch
-
-        # if now in "ST" mode: likewise do the same but with "ST"
-        # to memory, acknowledge address, and send out LD data
-        with m.If(st_active.q):
-            # set up LenExpander with the ST len and lower bits of addr
-            lsbaddr, msbaddr = self.splitaddr(stport.addr.data)
-            comb += lenexp.len_i.eq(stport.op.data_len)
-            comb += lenexp.addr_i.eq(lsbaddr)
-            with m.If(stport.addr.ok):
-                comb += wrport.addr.eq(msbaddr)  # addr ok, send thru
-                with m.If(adrok_l.qn):
-                    comb += stport.addr_ok_o.eq(1)  # acknowledge addr ok
-                    sync += adrok_l.s.eq(1)       # and pull "ack" latch
-
-        # NOTE: in both these, below, the port itself takes care
-        # of de-asserting its "busy_o" signal, based on either ld.ok going
-        # high (by us, here) or by st.ok going high (by the LDSTCompUnit).
-
-        # for LD mode, when addr has been "ok'd", assume that (because this
-        # is a "Memory" test-class) the memory read data is valid.
+        # until idx_l is deasserted
         comb += reset_l.s.eq(0)
         comb += reset_l.r.eq(0)
-        with m.If(ld_active.q & adrok_l.q):
-            # shift data down before pushing out.  requires masking
-            # from the *byte*-expanded version of LenExpand output
-            lddata = Signal(self.regwid, reset_less=True)
-            comb += lddata.eq((rdport.data & lenexp.rexp_o) >>
-                              (lenexp.addr_i*8))
-            # byte-reverse the data based on width
-            lddata_r = byte_reverse(m, lddata, lenexp.len_i)
-            comb += ldport.ld.data.eq(lddata_r)  # put data out
-            comb += ldport.ld.ok.eq(1)           # indicate data valid
-            comb += reset_l.s.eq(1)   # reset mode after 1 cycle
-
-        # for ST mode, when addr has been "ok'd", wait for incoming "ST ok"
-        with m.If(st_active.q & stport.st.ok):
-            # shift data up before storing.  lenexp *bit* version of mask is
-            # passed straight through as byte-level "write-enable" lines.
-            # byte-reverse the data based on width
-            stdata_r = byte_reverse(m, stport.st.data, lenexp.len_i)
-            stdata = Signal(self.regwid, reset_less=True)
-            comb += stdata.eq(stdata_r << (lenexp.addr_i*8))
-            comb += wrport.data.eq(stdata)  # write st to mem
-            comb += wrport.en.eq(lenexp.lexp_o) # enable writes
-            comb += reset_l.s.eq(1)   # reset mode after 1 cycle
+
+        with m.If(idx_l.q):
+            comb += self.pimem.connect_port(port)
+            with m.If(~self.pimem.pi.busy_o):
+                comb += reset_l.s.eq(1)  # reset when no longer busy
 
         # ugly hack, due to simultaneous addr req-go acknowledge
         reset_delay = Signal(reset_less=True)
         sync += reset_delay.eq(reset_l.q)
-        with m.If(reset_delay):
-            comb += adrok_l.r.eq(1)     # address reset
 
         # after waiting one cycle (reset_l is "sync" mode), reset the port
         with m.If(reset_l.q):
-            comb += idx_l.s.eq(1)  # deactivate port-index selector
-            comb += ld_active.r.eq(1)   # leave the ST active for 1 cycle
-            comb += st_active.r.eq(1)   # leave the ST active for 1 cycle
+            comb += idx_l.r.eq(1)  # deactivate port-index selector
             comb += reset_l.r.eq(1)     # clear reset
-            comb += adrok_l.r.eq(1)     # address reset
 
         return m
 
-    def ports(self):
+    def __iter__(self):
         for p in self.dports:
             yield from p.ports()
 
+    def ports(self):
+        return list(self)
+
 
 class TstL0CacheBuffer(Elaboratable):
-    def __init__(self, n_units=3, regwid=16, addrwid=4):
-        self.mem = TestMemory(regwid, addrwid, granularity=regwid//8)
-        self.l0 = L0CacheBuffer(n_units, self.mem, regwid, addrwid)
+    def __init__(self, pspec, n_units=3):
+        regwid = pspec.reg_wid
+        addrwid = pspec.addr_wid
+        self.cmpi = ConfigMemoryPortInterface(pspec)
+        self.pimem = self.cmpi.pi
+        self.l0 = L0CacheBuffer(n_units, self.pimem, regwid, addrwid << 1)
 
     def elaborate(self, platform):
         m = Module()
-        m.submodules.mem = self.mem
+        m.submodules.pimem = self.pimem
         m.submodules.l0 = self.l0
+        if hasattr(self.cmpi, 'lsmem'):  # hmmm not happy about this
+            m.submodules.lsmem = self.cmpi.lsmem.lsi
 
         return m
 
     def ports(self):
+        yield from self.cmpi.ports()
         yield from self.l0.ports()
-        yield self.mem.rdport.addr
-        yield self.mem.rdport.data
-        yield self.mem.wrport.addr
-        yield self.mem.wrport.data
-        # TODO: mem ports
+        yield from self.pimem.ports()
 
 
 def wait_busy(port, no=False):
     while True:
-        busy = yield port.pi.busy_o
+        busy = yield port.busy_o
         print("busy", no, busy)
         if bool(busy) == no:
             break
@@ -526,7 +341,7 @@ def wait_busy(port, no=False):
 
 def wait_addr(port):
     while True:
-        addr_ok = yield port.pi.addr_ok_o
+        addr_ok = yield port.addr_ok_o
         print("addrok", addr_ok)
         if not addr_ok:
             break
@@ -535,7 +350,7 @@ def wait_addr(port):
 
 def wait_ldok(port):
     while True:
-        ldok = yield port.pi.ld.ok
+        ldok = yield port.ld.ok
         print("ldok", ldok)
         if ldok:
             break
@@ -543,82 +358,20 @@ def wait_ldok(port):
 
 
 def l0_cache_st(dut, addr, data, datalen):
-    l0 = dut.l0
-    mem = dut.mem
-    port0 = l0.dports[0]
-    port1 = l0.dports[1]
-
-    # have to wait until not busy
-    yield from wait_busy(port1, no=False)    # wait until not busy
-
-    # set up a ST on the port.  address first:
-    yield port1.pi.is_st_i.eq(1)  # indicate ST
-    yield port1.pi.op.data_len.eq(datalen)  # ST length (1/2/4/8)
-
-    yield port1.pi.addr.data.eq(addr)  # set address
-    yield port1.pi.addr.ok.eq(1)  # set ok
-    yield from wait_addr(port1)             # wait until addr ok
-    # yield # not needed, just for checking
-    # yield # not needed, just for checking
-    # assert "ST" for one cycle (required by the API)
-    yield port1.pi.st.data.eq(data)
-    yield port1.pi.st.ok.eq(1)
-    yield
-    yield port1.pi.st.ok.eq(0)
-
-    # can go straight to reset.
-    yield port1.pi.is_st_i.eq(0)  # end
-    yield port1.pi.addr.ok.eq(0)  # set !ok
-    # yield from wait_busy(port1, False)    # wait until not busy
+    return pi_st(dut.l0, addr, datalen)
 
 
 def l0_cache_ld(dut, addr, datalen, expected):
-
-    l0 = dut.l0
-    mem = dut.mem
-    port0 = l0.dports[0]
-    port1 = l0.dports[1]
-
-    # have to wait until not busy
-    yield from wait_busy(port1, no=False)    # wait until not busy
-
-    # set up a LD on the port.  address first:
-    yield port1.pi.is_ld_i.eq(1)  # indicate LD
-    yield port1.pi.op.data_len.eq(datalen)  # LD length (1/2/4/8)
-
-    yield port1.pi.addr.data.eq(addr)  # set address
-    yield port1.pi.addr.ok.eq(1)  # set ok
-    yield from wait_addr(port1)             # wait until addr ok
-
-    yield from wait_ldok(port1)             # wait until ld ok
-    data = yield port1.pi.ld.data
-
-    # cleanup
-    yield port1.pi.is_ld_i.eq(0)  # end
-    yield port1.pi.addr.ok.eq(0)  # set !ok
-    # yield from wait_busy(port1, no=False)    # wait until not busy
-
-    return data
+    return pi_ld(dut.l0, addr, datalen)
 
 
 def l0_cache_ldst(arg, dut):
-    yield
-    addr = 0x2
-    data = 0xbeef
-    data2 = 0xf00f
-    #data = 0x4
-    yield from l0_cache_st(dut, 0x2, data, 2)
-    yield from l0_cache_st(dut, 0x4, data2, 2)
-    result = yield from l0_cache_ld(dut, 0x2, 2, data)
-    result2 = yield from l0_cache_ld(dut, 0x4, 2, data2)
-    yield
-    arg.assertEqual(data, result, "data %x != %x" % (result, data))
-    arg.assertEqual(data2, result2, "data2 %x != %x" % (result2, data2))
+    port0 = dut.l0.dports[0]
+    return pi_ldst(arg, port0)
 
 
 def data_merger_merge(dut):
-    print("data_merger")
-    #starting with all inputs zero
+    # starting with all inputs zero
     yield Settle()
     en = yield dut.data_o.en
     data = yield dut.data_o.data
@@ -638,46 +391,70 @@ def data_merger_merge(dut):
     assert en == 0xff
     yield
 
+def data_merger_test2(dut):
+    # starting with all inputs zero
+    yield Settle()
+    yield
+    yield
+
 
 class TestL0Cache(unittest.TestCase):
 
-    def test_l0_cache(self):
+    def test_l0_cache_test_bare_wb(self):
 
-        dut = TstL0CacheBuffer(regwid=64)
-        #vl = rtlil.convert(dut, ports=dut.ports())
-        #with open("test_basic_l0_cache.il", "w") as f:
-        #    f.write(vl)
+        pspec = TestMemPspec(ldst_ifacetype='test_bare_wb',
+                             addr_wid=48,
+                             mask_wid=8,
+                             reg_wid=64)
+        dut = TstL0CacheBuffer(pspec)
+        vl = rtlil.convert(dut, ports=[])  # TODOdut.ports())
+        with open("test_basic_l0_cache_bare_wb.il", "w") as f:
+            f.write(vl)
 
         run_simulation(dut, l0_cache_ldst(self, dut),
-                       vcd_name='test_l0_cache_basic.vcd')
+                       vcd_name='test_l0_cache_basic_bare_wb.vcd')
+
+    def test_l0_cache_testpi(self):
+
+        pspec = TestMemPspec(ldst_ifacetype='testpi',
+                             addr_wid=48,
+                             mask_wid=8,
+                             reg_wid=64)
+        dut = TstL0CacheBuffer(pspec)
+        vl = rtlil.convert(dut, ports=[])  # TODOdut.ports())
+        with open("test_basic_l0_cache.il", "w") as f:
+            f.write(vl)
+
+        run_simulation(dut, l0_cache_ldst(self, dut),
+                       vcd_name='test_l0_cache_basic_testpi.vcd')
 
 
 class TestDataMerger(unittest.TestCase):
 
     def test_data_merger(self):
 
-        dut = DataMerger(8)
+        dut = TstDataMerger2()
         #vl = rtlil.convert(dut, ports=dut.ports())
-        #with open("test_data_merger.il", "w") as f:
+        # with open("test_data_merger.il", "w") as f:
         #    f.write(vl)
 
-        run_simulation(dut, data_merger_merge(dut),
+        run_simulation(dut, data_merger_test2(dut),
                        vcd_name='test_data_merger.vcd')
 
 
+
 class TestDualPortSplitter(unittest.TestCase):
 
     def test_dual_port_splitter(self):
 
         dut = DualPortSplitter()
         #vl = rtlil.convert(dut, ports=dut.ports())
-        #with open("test_data_merger.il", "w") as f:
+        # with open("test_data_merger.il", "w") as f:
         #    f.write(vl)
 
-        #run_simulation(dut, data_merger_merge(dut),
+        # run_simulation(dut, data_merger_merge(dut),
         #               vcd_name='test_dual_port_splitter.vcd')
 
 
 if __name__ == '__main__':
-    unittest.main(exit=False)
-
+    unittest.main()