From 2d10caabbc083cdb615ee2035916505758e4944f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Marcelina=20Ko=C5=9Bcielnicka?= Date: Mon, 26 Oct 2020 03:20:57 +0100 Subject: [PATCH] memory_share: Improve sat-based port sharing. --- passes/memory/memory_share.cc | 268 +++++++++++++++++++--------------- 1 file changed, 151 insertions(+), 117 deletions(-) diff --git a/passes/memory/memory_share.cc b/passes/memory/memory_share.cc index bfd94a344..19afeb72d 100644 --- a/passes/memory/memory_share.cc +++ b/passes/memory/memory_share.cc @@ -139,13 +139,9 @@ struct MemoryShareWorker if (GetSize(mem.wr_ports) <= 1) return; - ezSatPtr ez; - SatGen satgen(ez.get(), &modwalker.sigmap); + // Get a list of ports that have any chance of being mergeable. - // find list of considered ports and port pairs - - std::set considered_ports; - std::set considered_port_pairs; + pool eligible_ports; for (int i = 0; i < GetSize(mem.wr_ports); i++) { auto &port = mem.wr_ports[i]; @@ -154,152 +150,190 @@ struct MemoryShareWorker if (bit == RTLIL::State::S1) goto port_is_always_active; if (modwalker.has_drivers(bits)) - considered_ports.insert(i); + eligible_ports.insert(i); port_is_always_active:; } + if (eligible_ports.size() <= 1) + return; + log("Consolidating write ports of memory %s.%s using sat-based resource sharing:\n", log_id(module), log_id(mem.memid)); - bool cache_clk_enable = false; - bool cache_clk_polarity = false; - RTLIL::SigSpec cache_clk; - int cache_wide_log2 = 0; + // Group eligible ports by clock domain and width. + pool checked_ports; + std::vector> groups; for (int i = 0; i < GetSize(mem.wr_ports); i++) { - auto &port = mem.wr_ports[i]; + auto &port1 = mem.wr_ports[i]; + if (!eligible_ports.count(i)) + continue; + if (checked_ports.count(i)) + continue; - if (port.clk_enable != cache_clk_enable || - port.wide_log2 != cache_wide_log2 || - (cache_clk_enable && (sigmap(port.clk) != cache_clk || - port.clk_polarity != cache_clk_polarity))) + + std::vector group; + group.push_back(i); + + for (int j = i + 1; j < GetSize(mem.wr_ports); j++) { - cache_clk_enable = port.clk_enable; - cache_clk_polarity = port.clk_polarity; - cache_clk = sigmap(port.clk); - cache_wide_log2 = port.wide_log2; + auto &port2 = mem.wr_ports[j]; + if (!eligible_ports.count(j)) + continue; + if (checked_ports.count(j)) + continue; + if (port1.clk_enable != port2.clk_enable) + continue; + if (port1.clk_enable) { + if (port1.clk != port2.clk) + continue; + if (port1.clk_polarity != port2.clk_polarity) + continue; + } + if (port1.wide_log2 != port2.wide_log2) + continue; + group.push_back(j); } - else if (i > 0 && considered_ports.count(i-1) && considered_ports.count(i)) - considered_port_pairs.insert(i); - - if (cache_clk_enable) - log(" Port %d on %s %s: %s\n", i, - cache_clk_polarity ? "posedge" : "negedge", log_signal(cache_clk), - considered_ports.count(i) ? "considered" : "not considered"); - else - log(" Port %d unclocked: %s\n", i, - considered_ports.count(i) ? "considered" : "not considered"); - } - if (considered_port_pairs.size() < 1) { - log(" No two subsequent ports in same clock domain considered -> nothing to consolidate.\n"); - return; - } + for (auto j : group) + checked_ports.insert(j); - // create SAT representation of common input cone of all considered EN signals - - pool one_hot_wires; - std::set sat_cells; - std::set bits_queue; - std::map port_to_sat_variable; + if (group.size() <= 1) + continue; - for (int i = 0; i < GetSize(mem.wr_ports); i++) - if (considered_port_pairs.count(i) || considered_port_pairs.count(i+1)) - { - RTLIL::SigSpec sig = modwalker.sigmap(mem.wr_ports[i].en); - port_to_sat_variable[i] = ez->expression(ez->OpOr, satgen.importSigSpec(sig)); + groups.push_back(group); + } - std::vector bits = sig; - bits_queue.insert(bits.begin(), bits.end()); + bool changed = false; + for (auto &group : groups) { + auto &some_port = mem.wr_ports[group[0]]; + string ports; + for (auto idx : group) { + if (idx != group[0]) + ports += ", "; + ports += std::to_string(idx); + } + if (!some_port.clk_enable) { + log(" Checking unclocked group, width %d: ports %s.\n", mem.width << some_port.wide_log2, ports.c_str()); + } else { + log(" Checking group clocked with %sedge %s, width %d: ports %s.\n", some_port.clk_polarity ? "pos" : "neg", log_signal(some_port.clk), mem.width << some_port.wide_log2, ports.c_str()); } - while (!bits_queue.empty()) - { - for (auto bit : bits_queue) - if (bit.wire && bit.wire->get_bool_attribute(ID::onehot)) - one_hot_wires.insert(bit.wire); - - pool portbits; - modwalker.get_drivers(portbits, bits_queue); - bits_queue.clear(); - - for (auto &pbit : portbits) - if (sat_cells.count(pbit.cell) == 0 && cone_ct.cell_known(pbit.cell->type)) { - pool &cell_inputs = modwalker.cell_inputs[pbit.cell]; - bits_queue.insert(cell_inputs.begin(), cell_inputs.end()); - sat_cells.insert(pbit.cell); - } - } + // Okay, time to actually run the SAT solver. - for (auto wire : one_hot_wires) { - log(" Adding one-hot constraint for wire %s.\n", log_id(wire)); - vector ez_wire_bits = satgen.importSigSpec(wire); - for (int i : ez_wire_bits) - for (int j : ez_wire_bits) - if (i != j) ez->assume(ez->NOT(i), j); - } + ezSatPtr ez; + SatGen satgen(ez.get(), &modwalker.sigmap); - log(" Common input cone for all EN signals: %d cells.\n", int(sat_cells.size())); + // create SAT representation of common input cone of all considered EN signals - for (auto cell : sat_cells) - satgen.importCell(cell); + pool one_hot_wires; + std::set sat_cells; + std::set bits_queue; + dict port_to_sat_variable; - log(" Size of unconstrained SAT problem: %d variables, %d clauses\n", ez->numCnfVariables(), ez->numCnfClauses()); + for (auto idx : group) { + RTLIL::SigSpec sig = modwalker.sigmap(mem.wr_ports[idx].en); + port_to_sat_variable[idx] = ez->expression(ez->OpOr, satgen.importSigSpec(sig)); - // merge subsequent ports if possible + std::vector bits = sig; + bits_queue.insert(bits.begin(), bits.end()); + } - bool changed = false; - for (int i = 0; i < GetSize(mem.wr_ports); i++) - { - if (!considered_port_pairs.count(i)) - continue; + while (!bits_queue.empty()) + { + for (auto bit : bits_queue) + if (bit.wire && bit.wire->get_bool_attribute(ID::onehot)) + one_hot_wires.insert(bit.wire); + + pool portbits; + modwalker.get_drivers(portbits, bits_queue); + bits_queue.clear(); + + for (auto &pbit : portbits) + if (sat_cells.count(pbit.cell) == 0 && cone_ct.cell_known(pbit.cell->type)) { + pool &cell_inputs = modwalker.cell_inputs[pbit.cell]; + bits_queue.insert(cell_inputs.begin(), cell_inputs.end()); + sat_cells.insert(pbit.cell); + } + } - if (ez->solve(port_to_sat_variable.at(i-1), port_to_sat_variable.at(i))) { - log(" According to SAT solver sharing of port %d with port %d is not possible.\n", i-1, i); - continue; + for (auto wire : one_hot_wires) { + log(" Adding one-hot constraint for wire %s.\n", log_id(wire)); + vector ez_wire_bits = satgen.importSigSpec(wire); + for (int i : ez_wire_bits) + for (int j : ez_wire_bits) + if (i != j) ez->assume(ez->NOT(i), j); } - log(" Merging port %d into port %d.\n", i-1, i); - port_to_sat_variable.at(i) = ez->OR(port_to_sat_variable.at(i-1), port_to_sat_variable.at(i)); + log(" Common input cone for all EN signals: %d cells.\n", int(sat_cells.size())); - RTLIL::SigSpec last_addr = mem.wr_ports[i-1].addr; - RTLIL::SigSpec last_data = mem.wr_ports[i-1].data; - std::vector last_en = modwalker.sigmap(mem.wr_ports[i-1].en); + for (auto cell : sat_cells) + satgen.importCell(cell); - RTLIL::SigSpec this_addr = mem.wr_ports[i].addr; - RTLIL::SigSpec this_data = mem.wr_ports[i].data; - std::vector this_en = modwalker.sigmap(mem.wr_ports[i].en); + log(" Size of unconstrained SAT problem: %d variables, %d clauses\n", ez->numCnfVariables(), ez->numCnfClauses()); - RTLIL::SigBit this_en_active = module->ReduceOr(NEW_ID, this_en); + // now try merging the ports. - if (GetSize(last_addr) < GetSize(this_addr)) - last_addr.extend_u0(GetSize(this_addr)); - else - this_addr.extend_u0(GetSize(last_addr)); + for (int ii = 0; ii < GetSize(group); ii++) { + int idx1 = group[ii]; + auto &port1 = mem.wr_ports[idx1]; + if (port1.removed) + continue; + for (int jj = ii + 1; jj < GetSize(group); jj++) { + int idx2 = group[jj]; + auto &port2 = mem.wr_ports[idx2]; + if (port2.removed) + continue; - mem.wr_ports[i].addr = module->Mux(NEW_ID, last_addr, this_addr, this_en_active); - mem.wr_ports[i].data = module->Mux(NEW_ID, last_data, this_data, this_en_active); + if (ez->solve(port_to_sat_variable.at(idx1), port_to_sat_variable.at(idx2))) { + log(" According to SAT solver sharing of port %d with port %d is not possible.\n", idx1, idx2); + continue; + } + + log(" Merging port %d into port %d.\n", idx2, idx1); + mem.prepare_wr_merge(idx1, idx2); + port_to_sat_variable.at(idx1) = ez->OR(port_to_sat_variable.at(idx1), port_to_sat_variable.at(idx2)); + + RTLIL::SigSpec last_addr = port1.addr; + RTLIL::SigSpec last_data = port1.data; + std::vector last_en = modwalker.sigmap(port1.en); + + RTLIL::SigSpec this_addr = port2.addr; + RTLIL::SigSpec this_data = port2.data; + std::vector this_en = modwalker.sigmap(port2.en); + + RTLIL::SigBit this_en_active = module->ReduceOr(NEW_ID, this_en); + + if (GetSize(last_addr) < GetSize(this_addr)) + last_addr.extend_u0(GetSize(this_addr)); + else + this_addr.extend_u0(GetSize(last_addr)); + + port1.addr = module->Mux(NEW_ID, last_addr, this_addr, this_en_active); + port1.data = module->Mux(NEW_ID, last_data, this_data, this_en_active); + + std::map, int> groups_en; + RTLIL::SigSpec grouped_last_en, grouped_this_en, en; + RTLIL::Wire *grouped_en = module->addWire(NEW_ID, 0); + + for (int j = 0; j < int(this_en.size()); j++) { + std::pair key(last_en[j], this_en[j]); + if (!groups_en.count(key)) { + grouped_last_en.append(last_en[j]); + grouped_this_en.append(this_en[j]); + groups_en[key] = grouped_en->width; + grouped_en->width++; + } + en.append(RTLIL::SigSpec(grouped_en, groups_en[key])); + } - std::map, int> groups_en; - RTLIL::SigSpec grouped_last_en, grouped_this_en, en; - RTLIL::Wire *grouped_en = module->addWire(NEW_ID, 0); + module->addMux(NEW_ID, grouped_last_en, grouped_this_en, this_en_active, grouped_en); + port1.en = en; - for (int j = 0; j < int(this_en.size()); j++) { - std::pair key(last_en[j], this_en[j]); - if (!groups_en.count(key)) { - grouped_last_en.append(last_en[j]); - grouped_this_en.append(this_en[j]); - groups_en[key] = grouped_en->width; - grouped_en->width++; + port2.removed = true; + changed = true; } - en.append(RTLIL::SigSpec(grouped_en, groups_en[key])); } - - module->addMux(NEW_ID, grouped_last_en, grouped_this_en, this_en_active, grouped_en); - mem.wr_ports[i].en = en; - - mem.wr_ports[i-1].removed = true; - changed = true; } if (changed) -- 2.30.2