memory_share: Improve same-address merging, recognize wide write ports.
author Marcelina Kościelnicka <mwk@0x04.net>
Sun, 25 Oct 2020 22:01:59 +0000 (23:01 +0100)
committer Marcelina Kościelnicka <mwk@0x04.net>
Thu, 27 May 2021 13:53:12 +0000 (15:53 +0200)
passes/memory/memory_share.cc

index 98637720c91423de21d8cf5c7ea8f92a2ad46f7a..bfd94a344abc6b06d2f9d8cd0d91c052cb92100a 100644 (file)
@@ -33,230 +33,100 @@ struct MemoryShareWorker
        SigMap sigmap, sigmap_xmux;
        ModWalker modwalker;
        CellTypes cone_ct;
+       bool flag_widen;
 
 
        // ------------------------------------------------------
        // Consolidate write ports that write to the same address
+       // (or close enough to be merged to wide ports)
        // ------------------------------------------------------
 
-       RTLIL::SigSpec mask_en_naive(RTLIL::SigSpec do_mask, RTLIL::SigSpec bits, RTLIL::SigSpec mask_bits)
-       {
-               // this is the naive version of the function that does not care about grouping the EN bits.
-
-               RTLIL::SigSpec inv_mask_bits = module->Not(NEW_ID, mask_bits);
-               RTLIL::SigSpec inv_mask_bits_filtered = module->Mux(NEW_ID, RTLIL::SigSpec(RTLIL::State::S1, bits.size()), inv_mask_bits, do_mask);
-               RTLIL::SigSpec result = module->And(NEW_ID, inv_mask_bits_filtered, bits);
-               return result;
-       }
-
-       RTLIL::SigSpec mask_en_grouped(RTLIL::SigSpec do_mask, RTLIL::SigSpec bits, RTLIL::SigSpec mask_bits)
-       {
-               // this version of the function preserves the bit grouping in the EN bits.
-
-               std::vector<RTLIL::SigBit> v_bits = bits;
-               std::vector<RTLIL::SigBit> v_mask_bits = mask_bits;
-
-               std::map<std::pair<RTLIL::SigBit, RTLIL::SigBit>, std::pair<int, std::vector<int>>> groups;
-               RTLIL::SigSpec grouped_bits, grouped_mask_bits;
-
-               for (int i = 0; i < bits.size(); i++) {
-                       std::pair<RTLIL::SigBit, RTLIL::SigBit> key(v_bits[i], v_mask_bits[i]);
-                       if (groups.count(key) == 0) {
-                               groups[key].first = grouped_bits.size();
-                               grouped_bits.append(v_bits[i]);
-                               grouped_mask_bits.append(v_mask_bits[i]);
-                       }
-                       groups[key].second.push_back(i);
-               }
-
-               std::vector<RTLIL::SigBit> grouped_result = mask_en_naive(do_mask, grouped_bits, grouped_mask_bits);
-               RTLIL::SigSpec result;
-
-               for (int i = 0; i < bits.size(); i++) {
-                       std::pair<RTLIL::SigBit, RTLIL::SigBit> key(v_bits[i], v_mask_bits[i]);
-                       result.append(grouped_result.at(groups.at(key).first));
-               }
-
-               return result;
-       }
-
-       void merge_en_data(RTLIL::SigSpec &merged_en, RTLIL::SigSpec &merged_data, RTLIL::SigSpec next_en, RTLIL::SigSpec next_data)
-       {
-               std::vector<RTLIL::SigBit> v_old_en = merged_en;
-               std::vector<RTLIL::SigBit> v_next_en = next_en;
-
-               // The new merged_en signal is just the old merged_en signal and next_en OR'ed together.
-               // But of course we need to preserve the bit grouping..
-
-               std::map<std::pair<RTLIL::SigBit, RTLIL::SigBit>, int> groups;
-               std::vector<RTLIL::SigBit> grouped_old_en, grouped_next_en;
-               RTLIL::SigSpec new_merged_en;
-
-               for (int i = 0; i < int(v_old_en.size()); i++) {
-                       std::pair<RTLIL::SigBit, RTLIL::SigBit> key(v_old_en[i], v_next_en[i]);
-                       if (groups.count(key) == 0) {
-                               groups[key] = grouped_old_en.size();
-                               grouped_old_en.push_back(key.first);
-                               grouped_next_en.push_back(key.second);
-                       }
-               }
-
-               std::vector<RTLIL::SigBit> grouped_new_en = module->Or(NEW_ID, grouped_old_en, grouped_next_en);
-
-               for (int i = 0; i < int(v_old_en.size()); i++) {
-                       std::pair<RTLIL::SigBit, RTLIL::SigBit> key(v_old_en[i], v_next_en[i]);
-                       new_merged_en.append(grouped_new_en.at(groups.at(key)));
-               }
-
-               // Create the new merged_data signal.
-
-               RTLIL::SigSpec new_merged_data(RTLIL::State::Sx, merged_data.size());
-
-               RTLIL::SigSpec old_data_set = module->And(NEW_ID, merged_en, merged_data);
-               RTLIL::SigSpec old_data_unset = module->And(NEW_ID, merged_en, module->Not(NEW_ID, merged_data));
-
-               RTLIL::SigSpec new_data_set = module->And(NEW_ID, next_en, next_data);
-               RTLIL::SigSpec new_data_unset = module->And(NEW_ID, next_en, module->Not(NEW_ID, next_data));
-
-               new_merged_data = module->Or(NEW_ID, new_merged_data, old_data_set);
-               new_merged_data = module->And(NEW_ID, new_merged_data, module->Not(NEW_ID, old_data_unset));
-
-               new_merged_data = module->Or(NEW_ID, new_merged_data, new_data_set);
-               new_merged_data = module->And(NEW_ID, new_merged_data, module->Not(NEW_ID, new_data_unset));
-
-               // Update merged_* signals
-
-               merged_en = new_merged_en;
-               merged_data = new_merged_data;
-       }
-
-       void consolidate_wr_by_addr(Mem &mem)
+       bool consolidate_wr_by_addr(Mem &mem)
        {
                if (GetSize(mem.wr_ports) <= 1)
-                       return;
+                       return false;
 
                log("Consolidating write ports of memory %s.%s by address:\n", log_id(module), log_id(mem.memid));
 
-               std::map<RTLIL::SigSpec, int> last_port_by_addr;
-               std::vector<std::vector<bool>> active_bits_on_port;
-
-               bool cache_clk_enable = false;
-               bool cache_clk_polarity = false;
-               RTLIL::SigSpec cache_clk;
-               int cache_wide_log2 = 0;
-
                bool changed = false;
-
                for (int i = 0; i < GetSize(mem.wr_ports); i++)
                {
-                       auto &port = mem.wr_ports[i];
-                       RTLIL::SigSpec addr = sigmap_xmux(port.addr);
-
-                       if (port.clk_enable != cache_clk_enable ||
-                                       port.wide_log2 != cache_wide_log2 ||
-                                       (cache_clk_enable && (sigmap(port.clk) != cache_clk ||
-                                       port.clk_polarity != cache_clk_polarity)))
-                       {
-                               cache_clk_enable = port.clk_enable;
-                               cache_clk_polarity = port.clk_polarity;
-                               cache_clk = sigmap(port.clk);
-                               cache_wide_log2 = port.wide_log2;
-                               last_port_by_addr.clear();
-
-                               if (cache_clk_enable)
-                                       log("  New clock domain: %s %s\n", cache_clk_polarity ? "posedge" : "negedge", log_signal(cache_clk));
-                               else
-                                       log("  New clock domain: unclocked\n");
-                       }
-
-                       log("    Port %d has addr %s.\n", i, log_signal(addr));
-
-                       log("      Active bits: ");
-                       std::vector<RTLIL::SigBit> en_bits = sigmap(port.en);
-                       active_bits_on_port.push_back(std::vector<bool>(en_bits.size()));
-                       for (int k = int(en_bits.size())-1; k >= 0; k--) {
-                               active_bits_on_port[i][k] = en_bits[k].wire != NULL || en_bits[k].data != RTLIL::State::S0;
-                               log("%c", active_bits_on_port[i][k] ? '1' : '0');
-                       }
-                       log("\n");
-
-                       if (last_port_by_addr.count(addr))
+                       auto &port1 = mem.wr_ports[i];
+                       if (port1.removed)
+                               continue;
+                       if (!port1.clk_enable)
+                               continue;
+                       for (int j = i + 1; j < GetSize(mem.wr_ports); j++)
                        {
-                               int last_i = last_port_by_addr.at(addr);
-                               log("      Merging port %d into this one.\n", last_i);
-
-                               bool found_overlapping_bits = false;
-                               for (int k = 0; k < int(en_bits.size()); k++) {
-                                       if (active_bits_on_port[i][k] && active_bits_on_port[last_i][k])
-                                               found_overlapping_bits = true;
-                                       active_bits_on_port[i][k] = active_bits_on_port[i][k] || active_bits_on_port[last_i][k];
-                               }
-
-                               // Force this ports addr input to addr directly (skip don't care muxes)
-
-                               port.addr = addr;
-
-                               // If any of the ports between `last_i' and `i' write to the same address, this
-                               // will have priority over whatever `last_i` wrote. So we need to revisit those
-                               // ports and mask the EN bits accordingly.
-
-                               RTLIL::SigSpec merged_en = sigmap(mem.wr_ports[last_i].en);
-
-                               for (int j = last_i+1; j < i; j++)
-                               {
-                                       if (mem.wr_ports[j].removed)
+                               auto &port2 = mem.wr_ports[j];
+                               if (port2.removed)
+                                       continue;
+                               if (!port2.clk_enable)
+                                       continue;
+                               if (port1.clk != port2.clk)
+                                       continue;
+                               if (port1.clk_polarity != port2.clk_polarity)
+                                       continue;
+                               // If the width of the ports doesn't match, they can still be
+                               // merged by widening the narrow one.  Check if the conditions
+                               // hold for that.
+                               int wide_log2 = std::max(port1.wide_log2, port2.wide_log2);
+                               if (GetSize(port1.addr) <= wide_log2)
+                                       continue;
+                               if (GetSize(port2.addr) <= wide_log2)
+                                       continue;
+                               if (!port1.addr.extract(0, wide_log2).is_fully_const())
+                                       continue;
+                               if (!port2.addr.extract(0, wide_log2).is_fully_const())
+                                       continue;
+                               if (sigmap_xmux(port1.addr.extract_end(wide_log2)) != sigmap_xmux(port2.addr.extract_end(wide_log2))) {
+                                       // Incompatible addresses after widening.  Last chance — widen both
+                                       // ports by one more bit to merge them.
+                                       if (!flag_widen)
+                                               continue;
+                                       wide_log2++;
+                                       if (sigmap_xmux(port1.addr.extract_end(wide_log2)) != sigmap_xmux(port2.addr.extract_end(wide_log2)))
+                                               continue;
+                                       if (!port1.addr.extract(0, wide_log2).is_fully_const())
+                                               continue;
+                                       if (!port2.addr.extract(0, wide_log2).is_fully_const())
                                                continue;
-
-                                       for (int k = 0; k < int(en_bits.size()); k++)
-                                               if (active_bits_on_port[i][k] && active_bits_on_port[j][k])
-                                                       goto found_overlapping_bits_i_j;
-
-                                       if (0) {
-                               found_overlapping_bits_i_j:
-                                               log("      Creating collosion-detect logic for port %d.\n", j);
-                                               RTLIL::SigSpec is_same_addr = module->addWire(NEW_ID);
-                                               module->addEq(NEW_ID, addr, mem.wr_ports[j].addr, is_same_addr);
-                                               merged_en = mask_en_grouped(is_same_addr, merged_en, sigmap(mem.wr_ports[j].en));
-                                       }
                                }
-
-                               // Then we need to merge the (masked) EN and the DATA signals.
-
-                               RTLIL::SigSpec merged_data = mem.wr_ports[last_i].data;
-                               if (found_overlapping_bits) {
-                                       log("      Creating logic for merging DATA and EN ports.\n");
-                                       merge_en_data(merged_en, merged_data, sigmap(port.en), sigmap(port.data));
-                               } else {
-                                       RTLIL::SigSpec cell_en = sigmap(port.en);
-                                       RTLIL::SigSpec cell_data = sigmap(port.data);
-                                       for (int k = 0; k < int(en_bits.size()); k++)
-                                               if (!active_bits_on_port[last_i][k]) {
-                                                       merged_en.replace(k, cell_en.extract(k, 1));
-                                                       merged_data.replace(k, cell_data.extract(k, 1));
-                                               }
+                               log("  Merging ports %d, %d (address %s).\n", i, j, log_signal(port1.addr));
+                               mem.prepare_wr_merge(i, j);
+                               port1.addr = sigmap_xmux(port1.addr);
+                               port2.addr = sigmap_xmux(port2.addr);
+                               mem.widen_wr_port(i, wide_log2);
+                               mem.widen_wr_port(j, wide_log2);
+                               int pos = 0;
+                               while (pos < GetSize(port1.data)) {
+                                       int epos = pos;
+                                       while (epos < GetSize(port1.data) && port1.en[epos] == port1.en[pos] && port2.en[epos] == port2.en[pos])
+                                               epos++;
+                                       int width = epos - pos;
+                                       SigBit new_en;
+                                       if (port2.en[pos] == State::S0) {
+                                               new_en = port1.en[pos];
+                                       } else if (port1.en[pos] == State::S0) {
+                                               port1.data.replace(pos, port2.data.extract(pos, width));
+                                               new_en = port2.en[pos];
+                                       } else {
+                                               port1.data.replace(pos, module->Mux(NEW_ID, port1.data.extract(pos, width), port2.data.extract(pos, width), port2.en[pos]));
+                                               new_en = module->Or(NEW_ID, port1.en[pos], port2.en[pos]);
+                                       }
+                                       for (int k = pos; k < epos; k++)
+                                               port1.en[k] = new_en;
+                                       pos = epos;
                                }
-
-                               // Connect the new EN and DATA signals and remove the old write port.
-
-                               port.en = merged_en;
-                               port.data = merged_data;
-
-                               mem.wr_ports[last_i].removed = true;
                                changed = true;
-
-                               log("      Active bits: ");
-                               std::vector<RTLIL::SigBit> en_bits = sigmap(port.en);
-                               active_bits_on_port.push_back(std::vector<bool>(en_bits.size()));
-                               for (int k = int(en_bits.size())-1; k >= 0; k--)
-                                       log("%c", active_bits_on_port[i][k] ? '1' : '0');
-                               log("\n");
+                               port2.removed = true;
                        }
-
-                       last_port_by_addr[addr] = i;
                }
 
                if (changed)
                        mem.emit();
+
+               return changed;
        }
 
 
@@ -441,7 +311,7 @@ struct MemoryShareWorker
        // Setup and run
        // -------------
 
-       MemoryShareWorker(RTLIL::Design *design) : design(design), modwalker(design) {}
+       MemoryShareWorker(RTLIL::Design *design, bool flag_widen) : design(design), modwalker(design), flag_widen(flag_widen) {}
 
        void operator()(RTLIL::Module* module)
        {
@@ -465,8 +335,9 @@ struct MemoryShareWorker
                        }
                }
 
-               for (auto &mem : memories)
-                       consolidate_wr_by_addr(mem);
+               for (auto &mem : memories) {
+                       while (consolidate_wr_by_addr(mem));
+               }
 
                cone_ct.setup_internals();
                cone_ct.cell_types.erase(ID($mul));
@@ -515,8 +386,10 @@ struct MemorySharePass : public Pass {
        }
        void execute(std::vector<std::string> args, RTLIL::Design *design) override {
                log_header(design, "Executing MEMORY_SHARE pass (consolidating $memrd/$memwr cells).\n");
+               // TODO: expose when wide ports are actually supported.
+               bool flag_widen = false;
                extra_args(args, 1, design);
-               MemoryShareWorker msw(design);
+               MemoryShareWorker msw(design, flag_widen);
 
                for (auto module : design->selected_modules())
                        msw(module);