From: Clifford Wolf Date: Sat, 20 Apr 2019 15:52:16 +0000 (+0200) Subject: Add "onehot" pass, improve "pmux2shiftx" onehot handling X-Git-Tag: yosys-0.9~186^2~1 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=97e9caa4fa1f874b693a9d948f48418f22babb6c;p=yosys.git Add "onehot" pass, improve "pmux2shiftx" onehot handling Signed-off-by: Clifford Wolf --- diff --git a/passes/opt/pmux2shiftx.cc b/passes/opt/pmux2shiftx.cc index 4cd061c68..5fd49a571 100644 --- a/passes/opt/pmux2shiftx.cc +++ b/passes/opt/pmux2shiftx.cc @@ -23,6 +23,172 @@ USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN +struct OnehotDatabase +{ + Module *module; + const SigMap &sigmap; + bool verbose = false; + + pool init_ones; + dict> sig_sources_db; + dict sig_onehot_cache; + pool recursion_guard; + + OnehotDatabase(Module *module, const SigMap &sigmap) : module(module), sigmap(sigmap) + { + } + + void initialize() + { + for (auto wire : module->wires()) + { + auto it = wire->attributes.find("\\init"); + if (it == wire->attributes.end()) + continue; + + auto &val = it->second; + int width = std::max(GetSize(wire), GetSize(val)); + + for (int i = 0; i < width; i++) + if (val[i] == State::S1) + init_ones.insert(sigmap(SigBit(wire, i))); + } + + for (auto cell : module->cells()) + { + vector inputs; + SigSpec output; + + if (cell->type.in("$adff", "$dff", "$dffe", "$dlatch", "$ff")) + { + output = cell->getPort("\\Q"); + if (cell->type == "$adff") + inputs.push_back(cell->getParam("\\ARST_VALUE")); + inputs.push_back(cell->getPort("\\D")); + } + + if (cell->type.in("$mux", "$pmux")) + { + output = cell->getPort("\\Y"); + inputs.push_back(cell->getPort("\\A")); + SigSpec B = cell->getPort("\\B"); + for (int i = 0; i < GetSize(B); i += GetSize(output)) + inputs.push_back(B.extract(i, GetSize(output))); + } + + if (!output.empty()) + { + output = sigmap(output); + auto &srcs = sig_sources_db[output]; + for (auto src : inputs) { + while (!src.empty() && src[GetSize(src)-1] == State::S0) + src.remove(GetSize(src)-1); + srcs.insert(sigmap(src)); + } + } + } + } + + void query_worker(const SigSpec &sig, bool &retval, bool &cache, int indent) + { + if (verbose) + log("%*s %s\n", indent, "", log_signal(sig)); + log_assert(retval); + + if (recursion_guard.count(sig)) { + if (verbose) + log("%*s - recursion\n", indent, ""); + cache = false; + return; + } + + auto it = sig_onehot_cache.find(sig); + if (it != sig_onehot_cache.end()) { + if (verbose) + log("%*s - cached (%s)\n", indent, "", it->second ? "true" : "false"); + if (!it->second) + retval = false; + return; + } + + bool found_init_ones = false; + for (auto bit : sig) { + if (init_ones.count(bit)) { + if (found_init_ones) { + if (verbose) + log("%*s - non-onehot init value\n", indent, ""); + retval = false; + break; + } + found_init_ones = true; + } + } + + if (retval) + { + if (sig.is_fully_const()) + { + bool found_ones = false; + for (auto bit : sig) { + if (bit == State::S1) { + if (found_ones) { + if (verbose) + log("%*s - non-onehot constant\n", indent, ""); + retval = false; + break; + } + found_ones = true; + } + } + } + else + { + auto srcs = sig_sources_db.find(sig); + if (srcs == sig_sources_db.end()) { + if (verbose) + log("%*s - no sources for non-const signal\n", indent, ""); + retval = false; + } else { + for (auto &src : srcs->second) { + bool child_cache = true; + recursion_guard.insert(sig); + query_worker(src, retval, child_cache, indent+4); + recursion_guard.erase(sig); + if (!child_cache) + cache = false; + if (!retval) + break; + } + } + } + } + + // it is always safe to cache a negative result + if (cache || !retval) + sig_onehot_cache[sig] = retval; + } + + bool query(const SigSpec &sig) + { + bool retval = true; + bool cache = true; + + if (verbose) + log("** ONEHOT QUERY START (%s)\n", log_signal(sig)); + + query_worker(sig, retval, cache, 3); + + if (verbose) + log("** ONEHOT QUERY RESULT = %s\n", retval ? "true" : "false"); + + // it is always safe to cache the root result of a query + if (!cache) + sig_onehot_cache[sig] = retval; + + return retval; + } +}; + struct Pmux2ShiftxPass : public Pass { Pmux2ShiftxPass() : Pass("pmux2shiftx", "transform $pmux cells to $shiftx cells") { } void help() YS_OVERRIDE @@ -33,6 +199,9 @@ struct Pmux2ShiftxPass : public Pass { log("\n"); log("This pass transforms $pmux cells to $shiftx cells.\n"); log("\n"); + log(" -v, -vv\n"); + log(" verbose output\n"); + log("\n"); log(" -min_density \n"); log(" specifies the minimum density for the shifter\n"); log(" default: 50\n"); @@ -41,9 +210,9 @@ struct Pmux2ShiftxPass : public Pass { log(" specified the minimum number of choices for a control signal\n"); log(" default: 3\n"); log("\n"); - log(" -allow_onehot\n"); - log(" by default, pmuxes with one-hot encoded control signals are not\n"); - log(" converted. this option disables that check.\n"); + log(" -onehot ignore|pmux|shiftx\n"); + log(" select strategy for one-hot encoded control signals\n"); + log(" default: pmux\n"); log("\n"); } void execute(std::vector args, RTLIL::Design *design) YS_OVERRIDE @@ -51,6 +220,9 @@ struct Pmux2ShiftxPass : public Pass { int min_density = 50; int min_choices = 3; bool allow_onehot = false; + bool optimize_onehot = true; + bool verbose = false; + bool verbose_onehot = false; log_header(design, "Executing PMUX2SHIFTX pass.\n"); @@ -64,8 +236,31 @@ struct Pmux2ShiftxPass : public Pass { min_choices = atoi(args[++argidx].c_str()); continue; } - if (args[argidx] == "-allow_onehot") { + if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "ignore") { + argidx++; + allow_onehot = false; + optimize_onehot = false; + continue; + } + if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "pmux") { + argidx++; + allow_onehot = false; + optimize_onehot = true; + continue; + } + if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "shiftx") { + argidx++; allow_onehot = true; + optimize_onehot = false; + continue; + } + if (args[argidx] == "-v") { + verbose = true; + continue; + } + if (args[argidx] == "-vv") { + verbose = true; + verbose_onehot = true; continue; } break; @@ -75,10 +270,15 @@ struct Pmux2ShiftxPass : public Pass { for (auto module : design->selected_modules()) { SigMap sigmap(module); + OnehotDatabase onehot_db(module, sigmap); + onehot_db.verbose = verbose_onehot; + + if (optimize_onehot) + onehot_db.initialize(); dict> eqdb; - for (auto cell : module->selected_cells()) + for (auto cell : module->cells()) { if (cell->type == "$eq") { @@ -181,6 +381,12 @@ struct Pmux2ShiftxPass : public Pass { bool printed_pmux_header = false; + if (verbose) { + printed_pmux_header = true; + log("Inspecting $pmux cell %s/%s.\n", log_id(module), log_id(cell)); + log(" data width: %d (next power-of-2 = %d, log2 = %d)\n", width, extwidth, width_bits); + } + SigSpec updated_S = cell->getPort("\\S"); SigSpec updated_B = cell->getPort("\\B"); @@ -196,7 +402,7 @@ struct Pmux2ShiftxPass : public Pass { } // find the relevant choices - bool is_onehot = true; + bool is_onehot = GetSize(sig) > 2; dict choices; for (int i : seldb.at(sig)) { Const val = eqdb.at(S[i]).second; @@ -211,14 +417,17 @@ struct Pmux2ShiftxPass : public Pass { // TBD: also find choices that are using signals that are subsets of the bits in "sig" - if (is_onehot && !allow_onehot) { - seldb.erase(sig); - continue; - } + if (!verbose) + { + if (is_onehot && !allow_onehot && !optimize_onehot) { + seldb.erase(sig); + continue; + } - if (GetSize(choices) < min_choices) { - seldb.erase(sig); - continue; + if (GetSize(choices) < min_choices) { + seldb.erase(sig); + continue; + } } if (!printed_pmux_header) { @@ -229,6 +438,65 @@ struct Pmux2ShiftxPass : public Pass { log(" checking ctrl signal %s\n", log_signal(sig)); + auto print_choices = [&]() { + log(" table of choices:\n"); + for (auto &it : choices) + log(" %3d: %s: %s\n", it.second, log_signal(it.first), + log_signal(B.extract(it.second*width, width))); + }; + + if (verbose) + { + if (is_onehot && !allow_onehot && !optimize_onehot) { + print_choices(); + log(" ignoring one-hot encoding.\n"); + seldb.erase(sig); + continue; + } + + if (GetSize(choices) < min_choices) { + print_choices(); + log(" insufficient choices.\n"); + seldb.erase(sig); + continue; + } + } + + if (is_onehot && optimize_onehot) + { + print_choices(); + if (!onehot_db.query(sig)) + { + log(" failed to detect onehot driver. do not optimize.\n"); + } + else + { + log(" optimizing one-hot encoding.\n"); + for (auto &it : choices) + { + const Const &val = it.first; + int index = -1; + + for (int i = 0; i < GetSize(val); i++) + if (val[i] == State::S1) { + log_assert(index < 0); + index = i; + } + + if (index < 0) { + log(" %3d: zero encoding.\n", it.second); + continue; + } + + SigBit new_ctrl = sig[index]; + log(" %3d: new crtl signal is %s.\n", it.second, log_signal(new_ctrl)); + updated_S[it.second] = new_ctrl; + } + } + seldb.erase(sig); + continue; + } + // find the best permutation vector perm_new_from_old(GetSize(sig)); Const perm_xormask(State::S0, GetSize(sig)); @@ -434,4 +702,127 @@ struct Pmux2ShiftxPass : public Pass { } } Pmux2ShiftxPass; +struct OnehotPass : public Pass { + OnehotPass() : Pass("onehot", "optimize $eq cells for onehot signals") { } + void help() YS_OVERRIDE + { + // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| + log("\n"); + log(" onehot [options] [selection]\n"); + log("\n"); + log("This pass optimizes $eq cells that compare one-hot signals against constants\n"); + log("\n"); + log(" -v, -vv\n"); + log(" verbose output\n"); + log("\n"); + } + void execute(std::vector args, RTLIL::Design *design) YS_OVERRIDE + { + bool verbose = false; + bool verbose_onehot = false; + + log_header(design, "Executing ONEHOT pass.\n"); + + size_t argidx; + for (argidx = 1; argidx < args.size(); argidx++) { + if (args[argidx] == "-v") { + verbose = true; + continue; + } + if (args[argidx] == "-vv") { + verbose = true; + verbose_onehot = true; + continue; + } + break; + } + extra_args(args, argidx, design); + + for (auto module : design->selected_modules()) + { + SigMap sigmap(module); + OnehotDatabase onehot_db(module, sigmap); + onehot_db.verbose = verbose_onehot; + onehot_db.initialize(); + + for (auto cell : module->selected_cells()) + { + if (cell->type != "$eq") + continue; + + SigSpec A = sigmap(cell->getPort("\\A")); + SigSpec B = sigmap(cell->getPort("\\B")); + + int a_width = cell->getParam("\\A_WIDTH").as_int(); + int b_width = cell->getParam("\\B_WIDTH").as_int(); + + if (a_width < b_width) { + bool a_signed = cell->getParam("\\A_SIGNED").as_int(); + A.extend_u0(b_width, a_signed); + } + + if (b_width < a_width) { + bool b_signed = cell->getParam("\\B_SIGNED").as_int(); + B.extend_u0(a_width, b_signed); + } + + if (A.is_fully_const()) + std::swap(A, B); + + if (!B.is_fully_const()) + continue; + + if (verbose) + log("Checking $eq(%s, %s) cell %s/%s.\n", log_signal(A), log_signal(B), log_id(module), log_id(cell)); + + if (!onehot_db.query(A)) { + if (verbose) + log(" onehot driver test on %s failed.\n", log_signal(A)); + continue; + } + + int index = -1; + bool not_onehot = false; + + for (int i = 0; i < GetSize(B); i++) { + if (B[i] != State::S1) + continue; + if (index >= 0) + not_onehot = true; + index = i; + } + + if (index < 0) { + if (verbose) + log(" not optimizing the zero pattern.\n"); + continue; + } + + SigSpec Y = cell->getPort("\\Y"); + + if (not_onehot) + { + if (verbose) + log(" replacing with constant 0 driver.\n"); + else + log("Replacing one-hot $eq(%s, %s) cell %s/%s with constant 0 driver.\n", log_signal(A), log_signal(B), log_id(module), log_id(cell)); + module->connect(Y, SigSpec(1, GetSize(Y))); + } + else + { + SigSpec sig = A[index]; + if (verbose) + log(" replacing with signal %s.\n", log_signal(sig)); + else + log("Replacing one-hot $eq(%s, %s) cell %s/%s with signal %s.\n",log_signal(A), log_signal(B), log_id(module), log_id(cell), log_signal(sig)); + sig.extend_u0(GetSize(Y)); + module->connect(Y, sig); + } + + module->remove(cell); + } + } + } +} OnehotPass; + PRIVATE_NAMESPACE_END