Complete rewrite of pmux2shiftx
authorClifford Wolf <clifford@clifford.at>
Fri, 19 Apr 2019 16:10:12 +0000 (18:10 +0200)
committerClifford Wolf <clifford@clifford.at>
Fri, 19 Apr 2019 22:38:25 +0000 (00:38 +0200)
Signed-off-by: Clifford Wolf <clifford@clifford.at>
passes/opt/pmux2shiftx.cc

index 6ffc27a4cb5c0a12f60e46ba86174a1d48d03150..8ea70fe84eab779c0a63b50345c23b61cbfcf9d4 100644 (file)
@@ -45,35 +45,273 @@ struct Pmux2ShiftxPass : public Pass {
                extra_args(args, argidx, design);
 
                for (auto module : design->selected_modules())
-               for (auto cell : module->selected_cells())
                {
-                       if (cell->type != "$pmux")
-                               continue;
-
-                       // Create a new encoder, out of a $pmux, that takes
-                       // the existing pmux's 'S' input and transforms it
-                       // back into a binary value
-                       RTLIL::SigSpec shiftx_a;
-                       RTLIL::SigSpec pmux_s;
-
-                       int s_width = cell->getParam("\\S_WIDTH").as_int();
-                       if (!cell->getPort("\\A").is_fully_undef()) {
-                               ++s_width;
-                               shiftx_a.append(cell->getPort("\\A"));
-                               pmux_s.append(module->Not(NEW_ID, module->ReduceOr(NEW_ID, cell->getPort("\\S"))));
+                       SigMap sigmap(module);
+
+                       dict<SigBit, pair<SigSpec, Const>> eqdb;
+
+                       for (auto cell : module->selected_cells())
+                       {
+                               if (cell->type == "$eq")
+                               {
+                                       dict<SigBit, State> bits;
+
+                                       SigSpec A = sigmap(cell->getPort("\\A"));
+                                       SigSpec B = sigmap(cell->getPort("\\B"));
+
+                                       int a_width = cell->getParam("\\A_WIDTH").as_int();
+                                       int b_width = cell->getParam("\\B_WIDTH").as_int();
+
+                                       if (a_width < b_width) {
+                                               bool a_signed = cell->getParam("\\A_SIGNED").as_int();
+                                               A.extend_u0(b_width, a_signed);
+                                       }
+
+                                       if (b_width < a_width) {
+                                               bool b_signed = cell->getParam("\\B_SIGNED").as_int();
+                                               B.extend_u0(a_width, b_signed);
+                                       }
+
+                                       for (int i = 0; i < GetSize(A); i++) {
+                                               SigBit a_bit = A[i], b_bit = B[i];
+                                               if (b_bit.wire && !a_bit.wire) {
+                                                       std::swap(a_bit, b_bit);
+                                               }
+                                               if (!a_bit.wire || b_bit.wire)
+                                                       goto next_cell;
+                                               if (bits.count(a_bit))
+                                                       goto next_cell;
+                                               bits[a_bit] = b_bit.data;
+                                       }
+
+                                       if (GetSize(bits) > 20)
+                                               goto next_cell;
+
+                                       bits.sort();
+                                       pair<SigSpec, Const> entry;
+
+                                       for (auto it : bits) {
+                                               entry.first.append_bit(it.first);
+                                               entry.second.bits.push_back(it.second);
+                                       }
+
+                                       eqdb[sigmap(cell->getPort("\\Y")[0])] = entry;
+                                       goto next_cell;
+                               }
+
+                               if (cell->type == "$logic_not")
+                               {
+                                       dict<SigBit, State> bits;
+
+                                       SigSpec A = sigmap(cell->getPort("\\A"));
+
+                                       for (int i = 0; i < GetSize(A); i++)
+                                               bits[A[i]] = State::S0;
+
+                                       bits.sort();
+                                       pair<SigSpec, Const> entry;
+
+                                       for (auto it : bits) {
+                                               entry.first.append_bit(it.first);
+                                               entry.second.bits.push_back(it.second);
+                                       }
+
+                                       eqdb[sigmap(cell->getPort("\\Y")[0])] = entry;
+                                       goto next_cell;
+                               }
+               next_cell:;
+                       }
+
+                       for (auto cell : module->selected_cells())
+                       {
+                               if (cell->type != "$pmux")
+                                       continue;
+
+                               string src = cell->get_src_attribute();
+                               int width = cell->getParam("\\WIDTH").as_int();
+                               int width_bits = ceil_log2(width);
+                               int extwidth = width;
+
+                               while (extwidth & (extwidth-1))
+                                       extwidth++;
+
+                               dict<SigSpec, pool<int>> seldb;
+
+                               SigSpec S = sigmap(cell->getPort("\\S"));
+                               for (int i = 0; i < GetSize(S); i++)
+                               {
+                                       if (!eqdb.count(S[i]))
+                                               continue;
+
+                                       auto &entry = eqdb.at(S[i]);
+                                       seldb[entry.first].insert(i);
+                               }
+
+                               if (seldb.empty())
+                                       continue;
+
+                               log("Inspecting $pmux cell %s/%s.\n", log_id(module), log_id(cell));
+                               log("  data width: %d (next power-of-2 = %d, log2 = %d)\n", width, extwidth, width_bits);
+
+                               SigSpec updated_S = cell->getPort("\\S");
+                               SigSpec updated_B = cell->getPort("\\B");
+
+                       #if 1
+                               for (auto &it : seldb) {
+                                       string msg = stringf("seldb: %s ->", log_signal(it.first));
+                                       for (int i : it.second)
+                                               msg += stringf(" %d(%s)", i, log_signal(eqdb.at(S[i]).second));
+                                       log("  %s\n", msg.c_str());
+                               }
+                       #endif
+
+                               while (!seldb.empty())
+                               {
+                                       // pick the largest entry in seldb
+                                       SigSpec sig = seldb.begin()->first;
+                                       for (auto &it : seldb) {
+                                               if (GetSize(sig) < GetSize(it.first))
+                                                       sig = it.first;
+                                               else if (GetSize(seldb.at(sig)) < GetSize(it.second))
+                                                       sig = it.first;
+                                       }
+
+                                       log("  checking ctrl signal %s\n", log_signal(sig));
+
+                                       // find the relevant choices
+                                       dict<Const, int> choices;
+                                       vector<int> onescnt(GetSize(sig));
+                                       for (int i : seldb.at(sig)) {
+                                               Const val = eqdb.at(S[i]).second;
+                                               choices[val] = i;
+                                               for (int k = 0; k < GetSize(val); k++)
+                                                       if (val[k] == State::S1)
+                                                               onescnt[k] |= 1;
+                                                       else
+                                                               onescnt[k] |= 2;
+                                       }
+
+                                       // TBD: also find choices that are using signals that are subsets of the bits in "sig"
+
+                                       // find the best permutation
+                                       vector<pair<int, int>> perm(GetSize(sig));
+                                       for (int i = 0; i < GetSize(onescnt); i++)
+                                               perm[i] = make_pair(onescnt[i], i);
+                                       // TBD: this is not the best permutation
+                                       std::sort(perm.rbegin(), perm.rend());
+
+                                       // permutated sig
+                                       Const perm_xormask(State::S0, GetSize(sig));
+                                       SigSpec perm_sig(State::S0, GetSize(sig));
+                                       for (int i = 0; i < GetSize(sig); i++) {
+                                               if (perm[i].first == 1)
+                                                       perm_xormask[i] = State::S1;
+                                               perm_sig[i] = sig[perm[i].second];
+                                       }
+
+                                       log("    best permutation: %s\n", log_signal(perm_sig));
+                                       log("    best xor mask: %s\n", log_signal(perm_xormask));
+
+                                       // permutated choices
+                                       int min_choice = 1 << 30;
+                                       int max_choice = -1;
+                                       dict<Const, int> perm_choices;
+
+                                       for (auto &it : choices)
+                                       {
+                                               Const &old_c = it.first;
+                                               Const new_c(State::S0, GetSize(old_c));
+
+                                               for (int i = 0; i < GetSize(old_c); i++)
+                                                       new_c[i] = old_c[perm[i].second];
+
+                                               Const new_c_before_xor = new_c;
+                                               new_c = const_xor(new_c, perm_xormask, false, false, GetSize(new_c));
+
+                                               perm_choices[new_c] = it.second;
+
+                                               min_choice = std::min(min_choice, new_c.as_int());
+                                               max_choice = std::max(max_choice, new_c.as_int());
+
+                                               log("      %s -> %s -> %s\n", log_signal(old_c), log_signal(new_c_before_xor), log_signal(new_c));
+                                       }
+
+                                       log("    choices: %d\n", GetSize(choices));
+                                       log("    min choice: %d\n", min_choice);
+                                       log("    max choice: %d\n", max_choice);
+                                       log("    range density: %d%%\n", 100*GetSize(choices)/(max_choice-min_choice+1));
+                                       log("    absolute density: %d%%\n", 100*GetSize(choices)/(max_choice+1));
+
+                                       bool full_case = (min_choice == 0) && (max_choice == (1 << GetSize(sig))-1) && (max_choice+1 == GetSize(choices));
+                                       log("    full case: %s\n", full_case ? "true" : "false");
+
+                                       // use arithmetic offset if density is less than 30%
+                                       Const offset(State::S0, GetSize(sig));
+                                       if (3*GetSize(choices) < max_choice && 3*GetSize(choices) >= (max_choice-min_choice))
+                                       {
+                                               log("    using offset method.\n");
+
+                                               offset = Const(min_choice, GetSize(sig));
+                                               min_choice -= offset.as_int();
+                                               max_choice -= offset.as_int();
+
+                                               dict<Const, int> new_perm_choices;
+                                               for (auto &it : perm_choices)
+                                                       new_perm_choices[const_sub(it.first, offset, false, false, GetSize(sig))] = it.second;
+                                               perm_choices.swap(new_perm_choices);
+                                       }
+
+                                       // ignore cases with a absolute density of less than 30%
+                                       if (3*GetSize(choices) < max_choice) {
+                                               log("    insufficient density.\n");
+                                               seldb.erase(sig);
+                                               continue;
+                                       }
+
+                                       // creat cmp signal
+                                       SigSpec cmp = perm_sig;
+                                       if (perm_xormask.as_bool())
+                                               cmp = module->Xor(NEW_ID, cmp, perm_xormask, false, src);
+                                       if (offset.as_bool())
+                                               cmp = module->Sub(NEW_ID, cmp, offset, false, src);
+
+                                       // create enable signal
+                                       SigBit en = State::S1;
+                                       if (!full_case) {
+                                               Const enable_mask(State::S0, max_choice+1);
+                                               for (auto &it : perm_choices)
+                                                       enable_mask[it.first.as_int()] = State::S1;
+                                               en = module->addWire(NEW_ID);
+                                               module->addShift(NEW_ID, enable_mask, cmp, en, false, src);
+                                       }
+
+                                       // create data signal
+                                       SigSpec data(State::Sx, (max_choice+1)*extwidth);
+                                       for (auto &it : perm_choices) {
+                                               int position = it.first.as_int()*extwidth;
+                                               int data_index = it.second;
+                                               data.replace(position, cell->getPort("\\B").extract(data_index*width, width));
+                                               updated_S[data_index] = State::S0;
+                                               updated_B.replace(data_index*width, SigSpec(State::Sx, width));
+                                       }
+
+                                       // create shiftx cell
+                                       SigSpec shifted_cmp = {cmp, SigSpec(State::S0, width_bits)};
+                                       SigSpec outsig = module->addWire(NEW_ID, width);
+                                       Cell *c = module->addShiftx(NEW_ID, data, shifted_cmp, outsig, false, src);
+                                       updated_S.append(en);
+                                       updated_B.append(outsig);
+                                       log("    created $shiftx cell %s.\n", log_id(c));
+
+                                       // remove this sig and continue with the next block
+                                       seldb.erase(sig);
+                               }
+
+                               // update $pmux cell
+                               cell->setPort("\\S", updated_S);
+                               cell->setPort("\\B", updated_B);
+                               cell->setParam("\\S_WIDTH", GetSize(updated_S));
                        }
-                       const int clog2width = ceil(log2(s_width));
-
-                       RTLIL::SigSpec pmux_b;
-                       for (int i = s_width-1; i >= 0; i--)
-                               pmux_b.append(RTLIL::Const(i, clog2width));
-                       shiftx_a.append(cell->getPort("\\B"));
-                       pmux_s.append(cell->getPort("\\S"));
-
-                       RTLIL::SigSpec pmux_y = module->addWire(NEW_ID, clog2width);
-                       module->addPmux(NEW_ID, RTLIL::Const(RTLIL::Sx, clog2width), pmux_b, pmux_s, pmux_y);
-                       module->addShiftx(NEW_ID, shiftx_a, pmux_y, cell->getPort("\\Y"));
-                       module->remove(cell);
                }
        }
 } Pmux2ShiftxPass;