Add "onehot" pass, improve "pmux2shiftx" onehot handling
authorClifford Wolf <clifford@clifford.at>
Sat, 20 Apr 2019 15:52:16 +0000 (17:52 +0200)
committerClifford Wolf <clifford@clifford.at>
Sat, 20 Apr 2019 15:52:16 +0000 (17:52 +0200)
Signed-off-by: Clifford Wolf <clifford@clifford.at>
passes/opt/pmux2shiftx.cc

index 4cd061c68bf4d247232764e2d2c04d75bfe8350c..5fd49a5713308548892ab4c1efcaac3a8fe1ce62 100644 (file)
 USING_YOSYS_NAMESPACE
 PRIVATE_NAMESPACE_BEGIN
 
+struct OnehotDatabase
+{
+       Module *module;
+       const SigMap &sigmap;
+       bool verbose = false;
+
+       pool<SigBit> init_ones;
+       dict<SigSpec, pool<SigSpec>> sig_sources_db;
+       dict<SigSpec, bool> sig_onehot_cache;
+       pool<SigSpec> recursion_guard;
+
+       OnehotDatabase(Module *module, const SigMap &sigmap) : module(module), sigmap(sigmap)
+       {
+       }
+
+       void initialize()
+       {
+               for (auto wire : module->wires())
+               {
+                       auto it = wire->attributes.find("\\init");
+                       if (it == wire->attributes.end())
+                               continue;
+
+                       auto &val = it->second;
+                       int width = std::max(GetSize(wire), GetSize(val));
+
+                       for (int i = 0; i < width; i++)
+                               if (val[i] == State::S1)
+                                       init_ones.insert(sigmap(SigBit(wire, i)));
+               }
+
+               for (auto cell : module->cells())
+               {
+                       vector<SigSpec> inputs;
+                       SigSpec output;
+
+                       if (cell->type.in("$adff", "$dff", "$dffe", "$dlatch", "$ff"))
+                       {
+                               output = cell->getPort("\\Q");
+                               if (cell->type == "$adff")
+                                       inputs.push_back(cell->getParam("\\ARST_VALUE"));
+                               inputs.push_back(cell->getPort("\\D"));
+                       }
+
+                       if (cell->type.in("$mux", "$pmux"))
+                       {
+                               output = cell->getPort("\\Y");
+                               inputs.push_back(cell->getPort("\\A"));
+                               SigSpec B = cell->getPort("\\B");
+                               for (int i = 0; i < GetSize(B); i += GetSize(output))
+                                       inputs.push_back(B.extract(i, GetSize(output)));
+                       }
+
+                       if (!output.empty())
+                       {
+                               output = sigmap(output);
+                               auto &srcs = sig_sources_db[output];
+                               for (auto src : inputs) {
+                                       while (!src.empty() && src[GetSize(src)-1] == State::S0)
+                                               src.remove(GetSize(src)-1);
+                                       srcs.insert(sigmap(src));
+                               }
+                       }
+               }
+       }
+
+       void query_worker(const SigSpec &sig, bool &retval, bool &cache, int indent)
+       {
+               if (verbose)
+                       log("%*s %s\n", indent, "", log_signal(sig));
+               log_assert(retval);
+
+               if (recursion_guard.count(sig)) {
+                       if (verbose)
+                               log("%*s   - recursion\n", indent, "");
+                       cache = false;
+                       return;
+               }
+
+               auto it = sig_onehot_cache.find(sig);
+               if (it != sig_onehot_cache.end()) {
+                       if (verbose)
+                               log("%*s   - cached (%s)\n", indent, "", it->second ? "true" : "false");
+                       if (!it->second)
+                               retval = false;
+                       return;
+               }
+
+               bool found_init_ones = false;
+               for (auto bit : sig) {
+                       if (init_ones.count(bit)) {
+                               if (found_init_ones) {
+                                       if (verbose)
+                                               log("%*s   - non-onehot init value\n", indent, "");
+                                       retval = false;
+                                       break;
+                               }
+                               found_init_ones = true;
+                       }
+               }
+
+               if (retval)
+               {
+                       if (sig.is_fully_const())
+                       {
+                               bool found_ones = false;
+                               for (auto bit : sig) {
+                                       if (bit == State::S1) {
+                                               if (found_ones) {
+                                                       if (verbose)
+                                                               log("%*s   - non-onehot constant\n", indent, "");
+                                                       retval = false;
+                                                       break;
+                                               }
+                                               found_ones = true;
+                                       }
+                               }
+                       }
+                       else
+                       {
+                               auto srcs = sig_sources_db.find(sig);
+                               if (srcs == sig_sources_db.end()) {
+                                       if (verbose)
+                                               log("%*s   - no sources for non-const signal\n", indent, "");
+                                       retval = false;
+                               } else {
+                                       for (auto &src : srcs->second) {
+                                               bool child_cache = true;
+                                               recursion_guard.insert(sig);
+                                               query_worker(src, retval, child_cache, indent+4);
+                                               recursion_guard.erase(sig);
+                                               if (!child_cache)
+                                                       cache = false;
+                                               if (!retval)
+                                                       break;
+                                       }
+                               }
+                       }
+               }
+
+               // it is always safe to cache a negative result
+               if (cache || !retval)
+                       sig_onehot_cache[sig] = retval;
+       }
+
+       bool query(const SigSpec &sig)
+       {
+               bool retval = true;
+               bool cache = true;
+
+               if (verbose)
+                       log("** ONEHOT QUERY START (%s)\n", log_signal(sig));
+
+               query_worker(sig, retval, cache, 3);
+
+               if (verbose)
+                       log("** ONEHOT QUERY RESULT = %s\n", retval ? "true" : "false");
+
+               // it is always safe to cache the root result of a query
+               if (!cache)
+                       sig_onehot_cache[sig] = retval;
+
+               return retval;
+       }
+};
+
 struct Pmux2ShiftxPass : public Pass {
        Pmux2ShiftxPass() : Pass("pmux2shiftx", "transform $pmux cells to $shiftx cells") { }
        void help() YS_OVERRIDE
@@ -33,6 +199,9 @@ struct Pmux2ShiftxPass : public Pass {
                log("\n");
                log("This pass transforms $pmux cells to $shiftx cells.\n");
                log("\n");
+               log("    -v, -vv\n");
+               log("        verbose output\n");
+               log("\n");
                log("    -min_density <percentage>\n");
                log("        specifies the minimum density for the shifter\n");
                log("        default: 50\n");
@@ -41,9 +210,9 @@ struct Pmux2ShiftxPass : public Pass {
                log("        specified the minimum number of choices for a control signal\n");
                log("        default: 3\n");
                log("\n");
-               log("    -allow_onehot\n");
-               log("        by default, pmuxes with one-hot encoded control signals are not\n");
-               log("        converted. this option disables that check.\n");
+               log("    -onehot ignore|pmux|shiftx\n");
+               log("        select strategy for one-hot encoded control signals\n");
+               log("        default: pmux\n");
                log("\n");
        }
        void execute(std::vector<std::string> args, RTLIL::Design *design) YS_OVERRIDE
@@ -51,6 +220,9 @@ struct Pmux2ShiftxPass : public Pass {
                int min_density = 50;
                int min_choices = 3;
                bool allow_onehot = false;
+               bool optimize_onehot = true;
+               bool verbose = false;
+               bool verbose_onehot = false;
 
                log_header(design, "Executing PMUX2SHIFTX pass.\n");
 
@@ -64,8 +236,31 @@ struct Pmux2ShiftxPass : public Pass {
                                min_choices = atoi(args[++argidx].c_str());
                                continue;
                        }
-                       if (args[argidx] == "-allow_onehot") {
+                       if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "ignore") {
+                               argidx++;
+                               allow_onehot = false;
+                               optimize_onehot = false;
+                               continue;
+                       }
+                       if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "pmux") {
+                               argidx++;
+                               allow_onehot = false;
+                               optimize_onehot = true;
+                               continue;
+                       }
+                       if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "shiftx") {
+                               argidx++;
                                allow_onehot = true;
+                               optimize_onehot = false;
+                               continue;
+                       }
+                       if (args[argidx] == "-v") {
+                               verbose = true;
+                               continue;
+                       }
+                       if (args[argidx] == "-vv") {
+                               verbose = true;
+                               verbose_onehot = true;
                                continue;
                        }
                        break;
@@ -75,10 +270,15 @@ struct Pmux2ShiftxPass : public Pass {
                for (auto module : design->selected_modules())
                {
                        SigMap sigmap(module);
+                       OnehotDatabase onehot_db(module, sigmap);
+                       onehot_db.verbose = verbose_onehot;
+
+                       if (optimize_onehot)
+                               onehot_db.initialize();
 
                        dict<SigBit, pair<SigSpec, Const>> eqdb;
 
-                       for (auto cell : module->selected_cells())
+                       for (auto cell : module->cells())
                        {
                                if (cell->type == "$eq")
                                {
@@ -181,6 +381,12 @@ struct Pmux2ShiftxPass : public Pass {
 
                                bool printed_pmux_header = false;
 
+                               if (verbose) {
+                                       printed_pmux_header = true;
+                                       log("Inspecting $pmux cell %s/%s.\n", log_id(module), log_id(cell));
+                                       log("  data width: %d (next power-of-2 = %d, log2 = %d)\n", width, extwidth, width_bits);
+                               }
+
                                SigSpec updated_S = cell->getPort("\\S");
                                SigSpec updated_B = cell->getPort("\\B");
 
@@ -196,7 +402,7 @@ struct Pmux2ShiftxPass : public Pass {
                                        }
 
                                        // find the relevant choices
-                                       bool is_onehot = true;
+                                       bool is_onehot = GetSize(sig) > 2;
                                        dict<Const, int> choices;
                                        for (int i : seldb.at(sig)) {
                                                Const val = eqdb.at(S[i]).second;
@@ -211,14 +417,17 @@ struct Pmux2ShiftxPass : public Pass {
 
                                        // TBD: also find choices that are using signals that are subsets of the bits in "sig"
 
-                                       if (is_onehot && !allow_onehot) {
-                                               seldb.erase(sig);
-                                               continue;
-                                       }
+                                       if (!verbose)
+                                       {
+                                               if (is_onehot && !allow_onehot && !optimize_onehot) {
+                                                       seldb.erase(sig);
+                                                       continue;
+                                               }
 
-                                       if (GetSize(choices) < min_choices) {
-                                               seldb.erase(sig);
-                                               continue;
+                                               if (GetSize(choices) < min_choices) {
+                                                       seldb.erase(sig);
+                                                       continue;
+                                               }
                                        }
 
                                        if (!printed_pmux_header) {
@@ -229,6 +438,65 @@ struct Pmux2ShiftxPass : public Pass {
 
                                        log("  checking ctrl signal %s\n", log_signal(sig));
 
+                                       auto print_choices = [&]() {
+                                               log("    table of choices:\n");
+                                               for (auto &it : choices)
+                                                       log("    %3d: %s: %s\n", it.second, log_signal(it.first),
+                                                                       log_signal(B.extract(it.second*width, width)));
+                                       };
+
+                                       if (verbose)
+                                       {
+                                               if (is_onehot && !allow_onehot && !optimize_onehot) {
+                                                       print_choices();
+                                                       log("    ignoring one-hot encoding.\n");
+                                                       seldb.erase(sig);
+                                                       continue;
+                                               }
+
+                                               if (GetSize(choices) < min_choices) {
+                                                       print_choices();
+                                                       log("    insufficient choices.\n");
+                                                       seldb.erase(sig);
+                                                       continue;
+                                               }
+                                       }
+
+                                       if (is_onehot && optimize_onehot)
+                                       {
+                                               print_choices();
+                                               if (!onehot_db.query(sig))
+                                               {
+                                                       log("    failed to detect onehot driver. do not optimize.\n");
+                                               }
+                                               else
+                                               {
+                                                       log("    optimizing one-hot encoding.\n");
+                                                       for (auto &it : choices)
+                                                       {
+                                                               const Const &val = it.first;
+                                                               int index = -1;
+
+                                                               for (int i = 0; i < GetSize(val); i++)
+                                                                       if (val[i] == State::S1) {
+                                                                               log_assert(index < 0);
+                                                                               index = i;
+                                                                       }
+
+                                                               if (index < 0) {
+                                                                       log("    %3d: zero encoding.\n", it.second);
+                                                                       continue;
+                                                               }
+
+                                                               SigBit new_ctrl = sig[index];
+                                                               log("    %3d: new crtl signal is %s.\n", it.second, log_signal(new_ctrl));
+                                                               updated_S[it.second] = new_ctrl;
+                                                       }
+                                               }
+                                               seldb.erase(sig);
+                                               continue;
+                                       }
+
                                        // find the best permutation
                                        vector<int> perm_new_from_old(GetSize(sig));
                                        Const perm_xormask(State::S0, GetSize(sig));
@@ -434,4 +702,127 @@ struct Pmux2ShiftxPass : public Pass {
        }
 } Pmux2ShiftxPass;
 
+struct OnehotPass : public Pass {
+       OnehotPass() : Pass("onehot", "optimize $eq cells for onehot signals") { }
+       void help() YS_OVERRIDE
+       {
+               //   |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|
+               log("\n");
+               log("    onehot [options] [selection]\n");
+               log("\n");
+               log("This pass optimizes $eq cells that compare one-hot signals against constants\n");
+               log("\n");
+               log("    -v, -vv\n");
+               log("        verbose output\n");
+               log("\n");
+       }
+       void execute(std::vector<std::string> args, RTLIL::Design *design) YS_OVERRIDE
+       {
+               bool verbose = false;
+               bool verbose_onehot = false;
+
+               log_header(design, "Executing ONEHOT pass.\n");
+
+               size_t argidx;
+               for (argidx = 1; argidx < args.size(); argidx++) {
+                       if (args[argidx] == "-v") {
+                               verbose = true;
+                               continue;
+                       }
+                       if (args[argidx] == "-vv") {
+                               verbose = true;
+                               verbose_onehot = true;
+                               continue;
+                       }
+                       break;
+               }
+               extra_args(args, argidx, design);
+
+               for (auto module : design->selected_modules())
+               {
+                       SigMap sigmap(module);
+                       OnehotDatabase onehot_db(module, sigmap);
+                       onehot_db.verbose = verbose_onehot;
+                       onehot_db.initialize();
+
+                       for (auto cell : module->selected_cells())
+                       {
+                               if (cell->type != "$eq")
+                                       continue;
+
+                               SigSpec A = sigmap(cell->getPort("\\A"));
+                               SigSpec B = sigmap(cell->getPort("\\B"));
+
+                               int a_width = cell->getParam("\\A_WIDTH").as_int();
+                               int b_width = cell->getParam("\\B_WIDTH").as_int();
+
+                               if (a_width < b_width) {
+                                       bool a_signed = cell->getParam("\\A_SIGNED").as_int();
+                                       A.extend_u0(b_width, a_signed);
+                               }
+
+                               if (b_width < a_width) {
+                                       bool b_signed = cell->getParam("\\B_SIGNED").as_int();
+                                       B.extend_u0(a_width, b_signed);
+                               }
+
+                               if (A.is_fully_const())
+                                       std::swap(A, B);
+
+                               if (!B.is_fully_const())
+                                       continue;
+
+                               if (verbose)
+                                       log("Checking $eq(%s, %s) cell %s/%s.\n", log_signal(A), log_signal(B), log_id(module), log_id(cell));
+
+                               if (!onehot_db.query(A)) {
+                                       if (verbose)
+                                               log("  onehot driver test on %s failed.\n", log_signal(A));
+                                       continue;
+                               }
+
+                               int index = -1;
+                               bool not_onehot = false;
+
+                               for (int i = 0; i < GetSize(B); i++) {
+                                       if (B[i] != State::S1)
+                                               continue;
+                                       if (index >= 0)
+                                               not_onehot = true;
+                                       index = i;
+                               }
+
+                               if (index < 0) {
+                                       if (verbose)
+                                               log("  not optimizing the zero pattern.\n");
+                                       continue;
+                               }
+
+                               SigSpec Y = cell->getPort("\\Y");
+
+                               if (not_onehot)
+                               {
+                                       if (verbose)
+                                               log("  replacing with constant 0 driver.\n");
+                                       else
+                                               log("Replacing one-hot $eq(%s, %s) cell %s/%s with constant 0 driver.\n", log_signal(A), log_signal(B), log_id(module), log_id(cell));
+                                       module->connect(Y, SigSpec(1, GetSize(Y)));
+                               }
+                               else
+                               {
+                                       SigSpec sig = A[index];
+                                       if (verbose)
+                                               log("  replacing with signal %s.\n", log_signal(sig));
+                                       else
+                                               log("Replacing one-hot $eq(%s, %s) cell %s/%s with signal %s.\n",log_signal(A), log_signal(B), log_id(module), log_id(cell), log_signal(sig));
+                                       sig.extend_u0(GetSize(Y));
+                                       module->connect(Y, sig);
+                               }
+
+                               module->remove(cell);
+                       }
+               }
+       }
+} OnehotPass;
+
 PRIVATE_NAMESPACE_END