from nmigen_soc.wishbone.sram import SRAM
 from nmigen import Memory, Signal, Module
 from soc.minerva.units.loadstore import BareLoadStoreUnit, CachedLoadStoreUnit
+from soc.minerva.units.fetch import BareFetchUnit, CachedFetchUnit
 
 
 class TestSRAMBareLoadStoreUnit(BareLoadStoreUnit):
         comb += sram.bus.adr.eq(dbus.adr)
 
         return m
+
+
+class TestSRAMBareFetchUnit(BareFetchUnit):
+    def __init__(self, addr_wid=64, data_wid=64):
+        super().__init__(addr_wid, data_wid)
+        # small 16-entry Memory
+        self.mem = Memory(width=self.data_wid, depth=32)
+
+    def _get_memory(self):
+        return self.mem
+
+    def elaborate(self, platform):
+        m = super().elaborate(platform)
+        comb = m.d.comb
+        m.submodules.sram = sram = SRAM(memory=self.mem, read_only=True,
+                                        features={'cti', 'bte', 'err'})
+        ibus = self.ibus
+
+        # directly connect the wishbone bus of FetchUnitInterface to SRAM
+        # note: SRAM is a target (slave), ibus is initiator (master)
+        fanouts = ['dat_w', 'sel', 'cyc', 'stb', 'we', 'cti', 'bte']
+        fanins = ['dat_r', 'ack', 'err']
+        for fanout in fanouts:
+            print ("fanout", fanout, getattr(sram.bus, fanout).shape(),
+                                     getattr(ibus, fanout).shape())
+            comb += getattr(sram.bus, fanout).eq(getattr(ibus, fanout))
+            comb += getattr(sram.bus, fanout).eq(getattr(ibus, fanout))
+        for fanin in fanins:
+            comb += getattr(ibus, fanin).eq(getattr(sram.bus, fanin))
+        # connect address
+        comb += sram.bus.adr.eq(ibus.adr)
+
+        return m
 
 of unnecessarily-duplicated code
 """
 from soc.experiment.imem import TestMemFetchUnit
-#from soc.bus.test.test_minerva import TestSRAMBareFetchUnit
+from soc.bus.test.test_minerva import TestSRAMBareFetchUnit
 
 
 class ConfigFetchUnit:
     def __init__(self, pspec):
         fudict = {'testmem': TestMemFetchUnit,
-                   #'test_bare_wb': TestSRAMBareFetchUnit,
+                   'test_bare_wb': TestSRAMBareFetchUnit,
                    #'test_cache_wb': TestCacheFetchUnit
                   }
         fukls = fudict[pspec.imem_ifacetype]
 
 def read_from_addr(dut, addr):
     yield dut.a_pc_i.eq(addr)
     yield dut.a_valid_i.eq(1)
+    yield dut.f_valid_i.eq(1)
     yield dut.a_stall_i.eq(1)
     yield
     yield dut.a_stall_i.eq(0)
     yield Settle()
     while (yield dut.f_busy_o):
         yield
-    assert (yield dut.a_valid_i)
-    return (yield dut.f_instr_o)
+    res = (yield dut.f_instr_o)
+
+    yield dut.a_valid_i.eq(0)
+    yield dut.f_valid_i.eq(0)
+    yield
+    return res
 
 
 def tst_lsmemtype(ifacetype):
     sim = Simulator(m)
     sim.add_clock(1e-6)
 
-    mem = dut.mem.mem
+    mem = dut._get_memory()
 
     def process():
 
         values = [random.randint(0, (1<<32)-1) for x in range(16)]
         for addr, val in enumerate(values):
             yield mem._array[addr].eq(val)
+        yield Settle()
 
         for addr, val in enumerate(values):
             x = yield from read_from_addr(dut, addr << 2)
         sim.run()
 
 if __name__ == '__main__':
-    #tst_lsmemtype('test_bare_wb')
+    tst_lsmemtype('test_bare_wb')
     tst_lsmemtype('testmem')
 
         # limit TestMemory to 2^6 entries of regwid size
         self.mem = TestMemory(self.data_wid, 6, readonly=True)
 
+    def _get_memory(self):
+        return self.mem.mem
+
     def elaborate(self, platform):
         m = Module()
         regwid, addrwid = self.data_wid, self.addr_wid
 
         m.d.comb += iba.bus.connect(self.ibus)
 
         icache_port = iba.port(priority=0)
-        cti = Mux(icache.bus_last, Cycle.END, Cycle.INCREMENT
+        cti = Mux(icache.bus_last, Cycle.END, Cycle.INCREMENT)
         m.d.comb += [
             icache_port.cyc.eq(icache.bus_re),
             icache_port.stb.eq(icache.bus_re),
 
         go_insn_i = Signal()
         pc_i = Signal(32)
 
-        m.submodules.issuer = issuer = TestIssuer(ifacetype="test_bare_wb")
-        imem = issuer.imem.mem.mem
+        m.submodules.issuer = issuer = TestIssuer(ifacetype="test_bare_wb",
+                                                  imemtype="test_bare_wb")
+        imem = issuer.imem._get_memory()
         core = issuer.core
         pdecode2 = core.pdecode2
         l0 = core.l0