# Read a tag from a tag memory row
 def read_tag(way, tagset):
-    return tagset.word_select(way, TAG_WIDTH)[:TAG_BITS]
+    return tagset.word_select(way, TAG_BITS)
 
 # Write a tag to tag memory row
 def write_tag(way, tagset, tag):
     # Cache hit detection, output to fetch2 and other misc logic
     def icache_comb(self, m, use_previous, r, req_index, req_row,
                     req_hit_way, req_tag, real_addr, req_laddr,
-                    cache_tags, cache_valids, access_ok,
+                    cache_valids, access_ok,
                     req_is_hit, req_is_miss, replace_way,
                     plru_victim, cache_out_row):
 
         comb = m.d.comb
+        m.submodules.rd_tag = rd_tag = self.tagmem.read_port(domain="comb")
 
         i_in, i_out, bus = self.i_in, self.i_out, self.bus
         flush_in, stall_out = self.flush_in, self.stall_out
         # i_in.req asserts Decoder active
         cvb = Signal(NUM_WAYS)
         ctag = Signal(TAG_RAM_WIDTH)
-        comb += ctag.eq(cache_tags[req_index])
+        comb += rd_tag.addr.eq(req_index)
+        comb += ctag.eq(rd_tag.data)
         comb += cvb.eq(cache_valids[req_index])
         m.submodules.store_way_e = se = Decoder(NUM_WAYS)
         comb += se.i.eq(r.store_way)
 
     def icache_miss_clr_tag(self, m, r, replace_way,
                             req_index,
-                            cache_tags, cache_valids):
+                            cache_valids):
         comb = m.d.comb
         sync = m.d.sync
+        m.submodules.wr_tag = wr_tag = self.tagmem.write_port(
+                                                    granularity=TAG_BITS)
 
         # Get victim way from plru
         sync += r.store_way.eq(replace_way)
         comb += cv.bit_select(replace_way, 1).eq(0)
         sync += cache_valids[req_index].eq(cv)
 
-        for i in range(NUM_WAYS):
-            with m.If(i == replace_way):
-                tagset = Signal(TAG_RAM_WIDTH)
-                comb += tagset.eq(cache_tags[r.store_index])
-                comb += write_tag(i, tagset, r.store_tag)
-                sync += cache_tags[r.store_index].eq(tagset)
+        # use write-port "granularity" to select the tag to write to
+        # TODO: the Memory should be multipled-up (by NUM_TAGS)
+        tagset = Signal(TAG_RAM_WIDTH)
+        comb += tagset.eq(r.store_tag << (replace_way*TAG_BITS))
+        comb += wr_tag.en.eq(1<<replace_way)
+        comb += wr_tag.addr.eq(r.store_index)
+        comb += wr_tag.data.eq(tagset)
 
         sync += r.state.eq(State.WAIT_ACK)
 
     # Cache miss/reload synchronous machine
     def icache_miss(self, m, r, req_is_miss,
                     req_index, req_laddr, req_tag, replace_way,
-                    cache_tags, cache_valids, access_ok, real_addr):
+                    cache_valids, access_ok, real_addr):
         comb = m.d.comb
         sync = m.d.sync
 
                 with m.If(r.state == State.CLR_TAG):
                     self.icache_miss_clr_tag(m, r, replace_way,
                                              req_index,
-                                             cache_tags, cache_valids)
+                                             cache_valids)
 
                 self.icache_miss_wait_ack(m, r, replace_way, inval_in,
                                           cache_valids, stbs_done)
         comb             = m.d.comb
 
         # Storage. Hopefully "cache_rows" is a BRAM, the rest is LUTs
-        cache_tags       = CacheTagArray()
         cache_valids     = CacheValidsArray()
 
         # TLB Array
         replace_way      = Signal(WAY_BITS)
 
         self.tlbmem = Memory(depth=TLB_SIZE, width=TLB_EA_TAG_BITS+TLB_PTE_BITS)
+        self.tagmem = Memory(depth=NUM_LINES, width=TAG_RAM_WIDTH)
 
         # call sub-functions putting everything together,
         # using shared signals established above
         self.itlb_update(m, itlb, itlb_valid)
         self.icache_comb(m, use_previous, r, req_index, req_row, req_hit_way,
                          req_tag, real_addr, req_laddr,
-                         cache_tags, cache_valids,
+                         cache_valids,
                          access_ok, req_is_hit, req_is_miss,
                          replace_way, plru_victim, cache_out_row)
         self.icache_hit(m, use_previous, r, req_is_hit, req_hit_way,
                         req_index, req_tag, real_addr)
         self.icache_miss(m, r, req_is_miss, req_index,
                          req_laddr, req_tag, replace_way,
-                         cache_tags, cache_valids,
+                         cache_valids,
                          access_ok, real_addr)
         #self.icache_log(m, log_out, req_hit_way, ra_valid, access_ok,
         #                req_is_miss, req_is_hit, lway, wstate, r)