Major refactoring of equiv_struct
authorClifford Wolf <clifford@clifford.at>
Sun, 25 Oct 2015 18:31:29 +0000 (19:31 +0100)
committerClifford Wolf <clifford@clifford.at>
Sun, 25 Oct 2015 18:31:29 +0000 (19:31 +0100)
kernel/hashlib.h
passes/equiv/equiv_struct.cc

index 3cc95b6e4b5fa5cc5599f7dcca4277a67a458278..4f5a353c5ec713d2b1331343ec493eafc93714c5 100644 (file)
@@ -162,6 +162,11 @@ struct hash_obj_ops {
        }
 };
 
+template<typename T>
+inline unsigned int mkhash(const T &v) {
+       return hash_ops<T>().hash(v);
+}
+
 inline int hashtable_size(int min_size)
 {
        static std::vector<int> zero_and_some_primes = {
index eae6d0fcf0f4f2da1a7bc4318087262c338ad6c6..aeac70bb5477daef0d04abff232cf28c2a7e7044 100644 (file)
@@ -28,98 +28,100 @@ struct EquivStructWorker
        Module *module;
        SigMap sigmap;
        SigMap equiv_bits;
+       bool mode_fwd;
        bool mode_icells;
        int merge_count;
 
-       dict<IdString, pool<IdString>> cells_by_type;
+       struct merge_key_t
+       {
+               IdString type;
+               vector<pair<IdString, Const>> parameters;
+               vector<pair<IdString, int>> port_sizes;
+               vector<tuple<IdString, int, SigBit>> connections;
+
+               bool operator==(const merge_key_t &other) const {
+                       return type == other.type && connections == other.connections &&
+                                       parameters == other.parameters && port_sizes == other.port_sizes;
+               }
 
-       void handle_cell_pair(Cell *cell_a, Cell *cell_b)
+               unsigned int hash() const {
+                       unsigned int h = mkhash_init;
+                       h = mkhash(h, mkhash(type));
+                       h = mkhash(h, mkhash(parameters));
+                       h = mkhash(h, mkhash(connections));
+                       return h;
+               }
+       };
+
+       dict<merge_key_t, pool<IdString>> merge_cache;
+       pool<merge_key_t> fwd_merge_cache, bwd_merge_cache;
+
+       void merge_cell_pair(Cell *cell_a, Cell *cell_b)
        {
-               if (cell_a->parameters != cell_b->parameters)
-                       return;
+               SigMap merged_map;
+               merge_count++;
 
-               bool merge_this_cells = false;
-               bool found_diff_inputs = false;
-               vector<SigSpec> inputs_a, inputs_b;
+               SigSpec inputs_a, inputs_b;
+               vector<string> input_names;
 
                for (auto &port_a : cell_a->connections())
                {
-                       SigSpec bits_a = equiv_bits(port_a.second);
-                       SigSpec bits_b = equiv_bits(cell_b->getPort(port_a.first));
+                       SigSpec bits_a = sigmap(port_a.second);
+                       SigSpec bits_b = sigmap(cell_b->getPort(port_a.first));
 
-                       if (GetSize(bits_a) != GetSize(bits_b))
-                               return;
+                       log_assert(GetSize(bits_a) == GetSize(bits_b));
 
-                       if (cell_a->output(port_a.first)) {
-                               for (int i = 0; i < GetSize(bits_a); i++)
-                                       if (bits_a[i] == bits_b[i])
-                                               merge_this_cells = true;
-                       } else {
-                               SigSpec diff_bits_a, diff_bits_b;
+                       if (!cell_a->output(port_a.first))
                                for (int i = 0; i < GetSize(bits_a); i++)
                                        if (bits_a[i] != bits_b[i]) {
-                                               diff_bits_a.append(bits_a[i]);
-                                               diff_bits_b.append(bits_b[i]);
+                                               inputs_a.append(bits_a[i]);
+                                               inputs_b.append(bits_b[i]);
+                                               input_names.push_back(GetSize(bits_a) == 1 ? port_a.first.str() :
+                                                               stringf("%s[%d]", log_id(port_a.first), i));
                                        }
-                               if (!diff_bits_a.empty()) {
-                                       inputs_a.push_back(diff_bits_a);
-                                       inputs_b.push_back(diff_bits_b);
-                                       found_diff_inputs = true;
-                               }
-                       }
                }
 
-               if (!found_diff_inputs)
-                       merge_this_cells = true;
-
-               if (merge_this_cells)
-               {
-                       SigMap merged_map;
-
-                       log("      Merging cells %s and %s.\n", log_id(cell_a),  log_id(cell_b));
-                       merge_count++;
-
-                       for (int i = 0; i < GetSize(inputs_a); i++) {
-                               SigSpec &sig_a = inputs_a[i], &sig_b = inputs_b[i];
-                               SigSpec sig_y = module->addWire(NEW_ID, GetSize(sig_a));
-                               log("        A: %s, B: %s, Y: %s\n", log_signal(sig_a),  log_signal(sig_b), log_signal(sig_y));
-                               module->addEquiv(NEW_ID, sig_a, sig_b, sig_y);
-                               merged_map.add(sig_a, sig_y);
-                               merged_map.add(sig_b, sig_y);
-                       }
-
-                       std::vector<IdString> outport_names, inport_names;
+               for (int i = 0; i < GetSize(inputs_a); i++) {
+                       SigBit bit_a = inputs_a[i], bit_b = inputs_b[i];
+                       SigBit bit_y = module->addWire(NEW_ID);
+                       log("      New $equiv for input %s: A: %s, B: %s, Y: %s\n",
+                                       input_names[i].c_str(), log_signal(bit_a), log_signal(bit_b), log_signal(bit_y));
+                       module->addEquiv(NEW_ID, bit_a, bit_b, bit_y);
+                       merged_map.add(bit_a, bit_y);
+                       merged_map.add(bit_b, bit_y);
+               }
 
-                       for (auto &port_a : cell_a->connections())
-                               if (cell_a->output(port_a.first))
-                                       outport_names.push_back(port_a.first);
-                               else
-                                       inport_names.push_back(port_a.first);
+               std::vector<IdString> outport_names, inport_names;
 
-                       for (auto &pn : inport_names)
-                               cell_a->setPort(pn, merged_map(equiv_bits(cell_a->getPort(pn))));
+               for (auto &port_a : cell_a->connections())
+                       if (cell_a->output(port_a.first))
+                               outport_names.push_back(port_a.first);
+                       else
+                               inport_names.push_back(port_a.first);
 
-                       for (auto &pn : outport_names) {
-                               SigSpec sig_a = cell_a->getPort(pn);
-                               SigSpec sig_b = cell_b->getPort(pn);
-                               module->connect(sig_b, sig_a);
-                               sigmap.add(sig_b, sig_a);
-                               equiv_bits.add(sig_b, sig_a);
-                       }
+               for (auto &pn : inport_names)
+                       cell_a->setPort(pn, merged_map(sigmap(cell_a->getPort(pn))));
 
-                       auto merged_attr = cell_b->get_strpool_attribute("\\equiv_merged");
-                       merged_attr.insert(log_id(cell_b));
-                       cell_a->add_strpool_attribute("\\equiv_merged", merged_attr);
-                       module->remove(cell_b);
+               for (auto &pn : outport_names) {
+                       SigSpec sig_a = cell_a->getPort(pn);
+                       SigSpec sig_b = cell_b->getPort(pn);
+                       module->connect(sig_b, sig_a);
                }
+
+               auto merged_attr = cell_b->get_strpool_attribute("\\equiv_merged");
+               merged_attr.insert(log_id(cell_b));
+               cell_a->add_strpool_attribute("\\equiv_merged", merged_attr);
+               module->remove(cell_b);
        }
 
-       EquivStructWorker(Module *module, bool mode_icells) :
-                       module(module), sigmap(module), equiv_bits(module), mode_icells(mode_icells), merge_count(0)
+       EquivStructWorker(Module *module, bool mode_fwd, bool mode_icells) :
+                       module(module), sigmap(module), equiv_bits(module),
+                       mode_fwd(mode_fwd), mode_icells(mode_icells), merge_count(0)
        {
                log("  Starting new iteration.\n");
 
                pool<SigBit> equiv_inputs;
+               pool<IdString> cells;
 
                for (auto cell : module->selected_cells())
                        if (cell->type == "$equiv") {
@@ -128,45 +130,104 @@ struct EquivStructWorker
                                equiv_bits.add(sig_b, sig_a);
                                equiv_inputs.insert(sig_a);
                                equiv_inputs.insert(sig_b);
-                               cells_by_type[cell->type].insert(cell->name);
-                       } else
-                       if (module->design->selected(module, cell)) {
+                               cells.insert(cell->name);
+                       } else {
                                if (mode_icells || module->design->module(cell->type))
-                                       cells_by_type[cell->type].insert(cell->name);
+                                       cells.insert(cell->name);
                        }
 
-               for (auto cell_name : cells_by_type["$equiv"]) {
-                       Cell *cell = module->cell(cell_name);
-                       SigBit sig_a = sigmap(cell->getPort("\\A").as_bit());
-                       SigBit sig_b = sigmap(cell->getPort("\\B").as_bit());
-                       SigBit sig_y = sigmap(cell->getPort("\\Y").as_bit());
-                       if (sig_a == sig_b && equiv_inputs.count(sig_y)) {
-                               log("    Purging redundant $equiv cell %s.\n", log_id(cell));
-                               module->remove(cell);
-                               merge_count++;
+               for (auto cell : module->selected_cells())
+                       if (cell->type == "$equiv") {
+                               SigBit sig_a = sigmap(cell->getPort("\\A").as_bit());
+                               SigBit sig_b = sigmap(cell->getPort("\\B").as_bit());
+                               SigBit sig_y = sigmap(cell->getPort("\\Y").as_bit());
+                               if (sig_a == sig_b && equiv_inputs.count(sig_y)) {
+                                       log("    Purging redundant $equiv cell %s.\n", log_id(cell));
+                                       module->remove(cell);
+                                       merge_count++;
+                               }
                        }
-               }
 
                if (merge_count > 0)
                        return;
 
-               for (auto &it : cells_by_type)
+               for (auto cell_name : cells)
                {
-                       if (it.second.size() <= 1)
-                               continue;
+                       merge_key_t key;
+                       vector<tuple<IdString, int, SigBit>> fwd_connections;
+
+                       Cell *cell = module->cell(cell_name);
+                       key.type = cell->type;
+
+                       for (auto &it : cell->parameters)
+                               key.parameters.push_back(it);
+                       std::sort(key.parameters.begin(), key.parameters.end());
+
+                       for (auto &it : cell->connections())
+                               key.port_sizes.push_back(make_pair(it.first, GetSize(it.second)));
+                       std::sort(key.port_sizes.begin(), key.port_sizes.end());
+
+                       for (auto &conn : cell->connections())
+                       {
+                               SigSpec sig = equiv_bits(conn.second);
+
+                               if (cell->input(conn.first))
+                                       for (int i = 0; i < GetSize(sig); i++)
+                                               fwd_connections.push_back(make_tuple(conn.first, i, sig[i]));
+
+                               if (cell->output(conn.first))
+                                       for (int i = 0; i < GetSize(sig); i++) {
+                                               key.connections.clear();
+                                               key.connections.push_back(make_tuple(conn.first, i, sig[i]));
+
+                                               if (merge_cache.count(key))
+                                                       bwd_merge_cache.insert(key);
+                                               merge_cache[key].insert(cell_name);
+                                       }
+                       }
+
+                       std::sort(fwd_connections.begin(), fwd_connections.end());
+                       key.connections.swap(fwd_connections);
+
+                       if (merge_cache.count(key))
+                               fwd_merge_cache.insert(key);
+                       merge_cache[key].insert(cell_name);
+               }
+
+               for (int phase = 0; phase < 2; phase++)
+               {
+                       auto &queue = phase ? bwd_merge_cache : fwd_merge_cache;
 
-                       log("    Merging %s cells..\n", log_id(it.first));
+                       for (auto &key : queue)
+                       {
+                               Cell *gold_cell = nullptr;
+                               pool<Cell*> cells;
 
-                       // FIXME: O(n^2)
-                       for (auto cell_name_a : it.second)
-                       for (auto cell_name_b : it.second)
-                               if (cell_name_a < cell_name_b) {
-                                       Cell *cell_a = module->cell(cell_name_a);
-                                       Cell *cell_b = module->cell(cell_name_b);
-                                       if (cell_a && cell_b)
-                                               handle_cell_pair(cell_a, cell_b);
+                               for (auto cell_name : merge_cache[key]) {
+                                       Cell *c = module->cell(cell_name);
+                                       if (c != nullptr) {
+                                               string n = cell_name.str();
+                                               if (gold_cell == nullptr || (GetSize(n) > 5 && n.substr(GetSize(n)-5) == "_gold"))
+                                                       gold_cell = c;
+                                               cells.insert(c);
+                                       }
                                }
+
+                               if (GetSize(cells) < 2)
+                                       continue;
+
+                               for (auto gate_cell : cells)
+                                       if (gate_cell != gold_cell) {
+                                               log("    %s merging cells %s and %s.\n", phase ? "Bwd" : "Fwd", log_id(gold_cell),  log_id(gate_cell));
+                                               merge_cell_pair(gold_cell, gate_cell);
+                                       }
+                       }
+
+                       if (merge_count > 0)
+                               return;
                }
+
+               log("    Nothing to merge.\n");
        }
 };
 
@@ -184,6 +245,12 @@ struct EquivStructPass : public Pass {
                log("for example when analyzing circuits with cells with commutative inputs. This\n");
                log("command will also de-duplicate gates.\n");
                log("\n");
+               log("    -fwd\n");
+               log("        by default this command performans forward sweeps until nothing can\n");
+               log("        be merged by forwards sweeps, the backward sweeps until forward\n");
+               log("        sweeps are effective again. with this option set only forward sweeps\n");
+               log("        are performed.\n");
+               log("\n");
                log("    -icells\n");
                log("        by default, the internal RTL and gate cell types are ignored. add\n");
                log("        this option to also process those cell types with this command.\n");
@@ -192,11 +259,16 @@ struct EquivStructPass : public Pass {
        virtual void execute(std::vector<std::string> args, Design *design)
        {
                bool mode_icells = false;
+               bool mode_fwd = false;
 
                log_header("Executing EQUIV_STRUCT pass.\n");
 
                size_t argidx;
                for (argidx = 1; argidx < args.size(); argidx++) {
+                       if (args[argidx] == "-fwd") {
+                               mode_fwd = true;
+                               continue;
+                       }
                        if (args[argidx] == "-icells") {
                                mode_icells = true;
                                continue;
@@ -206,9 +278,9 @@ struct EquivStructPass : public Pass {
                extra_args(args, argidx, design);
 
                for (auto module : design->selected_modules()) {
-                       log("Running equiv_struct on module %s:", log_id(module));
+                       log("Running equiv_struct on module %s:\n", log_id(module));
                        while (1) {
-                               EquivStructWorker worker(module, mode_icells);
+                               EquivStructWorker worker(module, mode_fwd, mode_icells);
                                if (worker.merge_count == 0)
                                        break;
                        }