Add more mutation types, improve mutation src cover
authorClifford Wolf <clifford@clifford.at>
Thu, 14 Mar 2019 18:52:02 +0000 (19:52 +0100)
committerClifford Wolf <clifford@clifford.at>
Thu, 14 Mar 2019 21:04:42 +0000 (22:04 +0100)
Signed-off-by: Clifford Wolf <clifford@clifford.at>
passes/sat/mutate.cc

index 3ecb955f02770ba1a41cfcfa31c71446f1f26b5f..eac00948ab459a611de40522e26d793bf8f3e948 100644 (file)
@@ -24,20 +24,24 @@ USING_YOSYS_NAMESPACE
 PRIVATE_NAMESPACE_BEGIN
 
 struct mutate_t {
-       std::string mode, src;
-       Module *module;
-       Cell *cell;
-       IdString cellport;
-       SigBit outsigbit;
+       string mode;
+       pool<string> src;
+       IdString module, cell;
+       IdString port, wire;
        int portbit = -1;
+       int ctrlbit = -1;
+       int wirebit = -1;
        bool used = false;
 };
 
 struct mutate_opts_t {
        int seed = 0;
        std::string mode;
-       IdString module, cell, port;
-       int bit = -1;
+       pool<string> src;
+       IdString module, cell, port, wire;
+       int portbit = -1;
+       int ctrlbit = -1;
+       int wirebit = -1;
        IdString ctrl_name;
        int ctrl_width, ctrl_value;
 };
@@ -47,16 +51,35 @@ void database_add(std::vector<mutate_t> &database, const mutate_opts_t &opts, co
        if (!opts.mode.empty() && opts.mode != entry.mode)
                return;
 
-       if (!opts.module.empty() && opts.module != entry.module->name)
+       if (!opts.src.empty()) {
+               bool found_match = false;
+               for (auto &s : opts.src) {
+                       if (entry.src.count(s))
+                               found_match = true;
+               }
+               if (!found_match)
+                       return;
+       }
+
+       if (!opts.module.empty() && opts.module != entry.module)
+               return;
+
+       if (!opts.cell.empty() && opts.cell != entry.cell)
+               return;
+
+       if (!opts.port.empty() && opts.port != entry.port)
+               return;
+
+       if (opts.portbit >= 0 && opts.portbit != entry.portbit)
                return;
 
-       if (!opts.cell.empty() && opts.cell != entry.cell->name)
+       if (opts.ctrlbit >= 0 && opts.ctrlbit != entry.ctrlbit)
                return;
 
-       if (!opts.port.empty() && opts.port != entry.cellport)
+       if (!opts.wire.empty() && opts.wire != entry.wire)
                return;
 
-       if (opts.bit >= 0 && opts.bit != entry.portbit)
+       if (opts.wirebit >= 0 && opts.wirebit != entry.wirebit)
                return;
 
        database.push_back(entry);
@@ -99,18 +122,48 @@ struct mutate_leaf_queue_t
 {
        pool<mutate_t*, hash_ptr_ops> db;
 
-       mutate_t *pick(xs128_t &rng) {
-               while (!db.empty()) {
-                       int i = rng(GetSize(db));
-                       auto it = db.element(i);
-                       mutate_t *m = *it;
-                       db.erase(it);
-                       if (m->used == false) {
-                               m->used = true;
-                               return m;
+       mutate_t *pick(xs128_t &rng, dict<string, int> &coverdb) {
+               mutate_t *m = nullptr;
+               if (rng(3)) {
+                       vector<mutate_t*> candidates;
+                       int best_score = -1;
+                       for (auto p : db) {
+                               if (p->used || p->src.empty())
+                                       continue;
+                               int this_score = -1;
+                               for (auto &s : p->src) {
+                                       if (this_score == -1 || this_score > coverdb.at(s))
+                                               this_score = coverdb.at(s);
+                               }
+                               log_assert(this_score != -1);
+                               if (best_score == -1 || this_score < best_score) {
+                                       candidates.clear();
+                                       best_score = this_score;
+                               }
+                               if (best_score == this_score)
+                                       candidates.push_back(p);
                        }
+                       if (!candidates.empty())
+                               m = candidates[rng(GetSize(candidates))];
                }
-               return nullptr;
+               if (m == nullptr) {
+                       while (!db.empty()) {
+                               int i = rng(GetSize(db));
+                               auto it = db.element(i);
+                               mutate_t *p = *it;
+                               db.erase(it);
+                               if (p->used == false) {
+                                       m = p;
+                                       break;
+                               }
+                       }
+               }
+               if (m != nullptr) {
+                       m->used = true;
+                       for (auto &s : m->src)
+                               coverdb[s]++;
+               }
+               return m;
        }
 
        void add(mutate_t *m) {
@@ -123,11 +176,11 @@ struct mutate_inner_queue_t
 {
        dict<K, T> db;
 
-       mutate_t *pick(xs128_t &rng) {
+       mutate_t *pick(xs128_t &rng, dict<string, int> &coverdb) {
                while (!db.empty()) {
                        int i = rng(GetSize(db));
                        auto it = db.element(i);
-                       mutate_t *m = it->second.pick(rng);
+                       mutate_t *m = it->second.pick(rng, coverdb);
                        if (m != nullptr)
                                return m;
                        db.erase(it);
@@ -141,66 +194,67 @@ struct mutate_inner_queue_t
        }
 };
 
-void database_reduce(std::vector<mutate_t> &database, const mutate_opts_t &opts, int N)
+void database_reduce(std::vector<mutate_t> &database, const mutate_opts_t &/* opts */, int N, xs128_t &rng)
 {
+       std::vector<mutate_t> new_database;
+       dict<string, int> coverdb;
+
+       int weight_pq_w = 100;
+       int weight_pq_b = 100;
+       int weight_pq_c = 100;
+       int weight_pq_s = 100;
+
+       int weight_pq_mw = 100;
+       int weight_pq_mb = 100;
+       int weight_pq_mc = 100;
+       int weight_pq_ms = 100;
+
+       int total_weight = weight_pq_w + weight_pq_b + weight_pq_c + weight_pq_s;
+       total_weight += weight_pq_mw + weight_pq_mb + weight_pq_mc + weight_pq_ms;
+
        if (N >= GetSize(database))
                return;
 
-       mutate_inner_queue_t<Wire*, mutate_leaf_queue_t> primary_queue_wire;
-       mutate_inner_queue_t<SigBit, mutate_leaf_queue_t> primary_queue_bit;
-       mutate_inner_queue_t<Cell*, mutate_leaf_queue_t> primary_queue_cell;
+       mutate_inner_queue_t<IdString, mutate_leaf_queue_t> primary_queue_wire;
+       mutate_inner_queue_t<pair<IdString, int>, mutate_leaf_queue_t> primary_queue_bit;
+       mutate_inner_queue_t<IdString, mutate_leaf_queue_t> primary_queue_cell;
        mutate_inner_queue_t<string, mutate_leaf_queue_t> primary_queue_src;
 
-       mutate_inner_queue_t<Module*, mutate_inner_queue_t<Wire*, mutate_leaf_queue_t>> primary_queue_module_wire;
-       mutate_inner_queue_t<Module*, mutate_inner_queue_t<SigBit, mutate_leaf_queue_t>> primary_queue_module_bit;
-       mutate_inner_queue_t<Module*, mutate_inner_queue_t<Cell*, mutate_leaf_queue_t>> primary_queue_module_cell;
-       mutate_inner_queue_t<Module*, mutate_inner_queue_t<string, mutate_leaf_queue_t>> primary_queue_module_src;
+       mutate_inner_queue_t<IdString, mutate_inner_queue_t<IdString, mutate_leaf_queue_t>> primary_queue_module_wire;
+       mutate_inner_queue_t<IdString, mutate_inner_queue_t<pair<IdString, int>, mutate_leaf_queue_t>> primary_queue_module_bit;
+       mutate_inner_queue_t<IdString, mutate_inner_queue_t<IdString, mutate_leaf_queue_t>> primary_queue_module_cell;
+       mutate_inner_queue_t<IdString, mutate_inner_queue_t<string, mutate_leaf_queue_t>> primary_queue_module_src;
 
        for (auto &m : database)
        {
-               if (m.outsigbit.wire) {
-                       primary_queue_wire.add(&m, m.outsigbit.wire);
-                       primary_queue_bit.add(&m, m.outsigbit);
-                       primary_queue_module_wire.add(&m, m.module, m.outsigbit.wire);
-                       primary_queue_module_bit.add(&m, m.module, m.outsigbit);
+               if (!m.wire.empty()) {
+                       primary_queue_wire.add(&m, m.wire);
+                       primary_queue_bit.add(&m, pair<IdString, int>(m.wire, m.wirebit));
+                       primary_queue_module_wire.add(&m, m.module, m.wire);
+                       primary_queue_module_bit.add(&m, m.module, pair<IdString, int>(m.wire, m.wirebit));
                }
 
                primary_queue_cell.add(&m, m.cell);
                primary_queue_module_cell.add(&m, m.module, m.cell);
 
-               if (!m.src.empty()) {
-                       primary_queue_src.add(&m, m.src);
-                       primary_queue_module_src.add(&m, m.module, m.src);
+               for (auto &s : m.src) {
+                       coverdb[s] = 0;
+                       primary_queue_src.add(&m, s);
+                       primary_queue_module_src.add(&m, m.module, s);
                }
        }
 
-       int weight_pq_w = 100;
-       int weight_pq_b = 100;
-       int weight_pq_c = 100;
-       int weight_pq_s = 100;
-
-       int weight_pq_mw = 100;
-       int weight_pq_mb = 100;
-       int weight_pq_mc = 100;
-       int weight_pq_ms = 100;
-
-       int total_weight = weight_pq_w + weight_pq_b + weight_pq_c + weight_pq_s;
-       total_weight += weight_pq_mw + weight_pq_mb + weight_pq_mc + weight_pq_ms;
-
-       std::vector<mutate_t> new_database;
-       xs128_t rng(opts.seed);
-
        while (GetSize(new_database) < N)
        {
                int k = rng(total_weight);
 
-#define X(__wght, __queue)                \
-    k -= __wght;                          \
-    if (k < 0) {                          \
-      mutate_t *m = __queue.pick(rng);    \
-      if (m != nullptr)                   \
-        new_database.push_back(*m);       \
-      continue;                           \
+#define X(__wght, __queue)                         \
+    k -= __wght;                                   \
+    if (k < 0) {                                   \
+      mutate_t *m = __queue.pick(rng, coverdb);    \
+      if (m != nullptr)                            \
+        new_database.push_back(*m);                \
+      continue;                                    \
     }
 
                X(weight_pq_w, primary_queue_wire)
@@ -215,11 +269,19 @@ void database_reduce(std::vector<mutate_t> &database, const mutate_opts_t &opts,
        }
 
        std::swap(new_database, database);
+
+       int covered_cnt = 0;
+       for (auto &it : coverdb)
+               if (it.second)
+                       covered_cnt++;
+
+       log("Covered %d/%d src attributes (%.2f%%).\n", covered_cnt, GetSize(coverdb), 100.0 * covered_cnt / GetSize(coverdb));
 }
 
 void mutate_list(Design *design, const mutate_opts_t &opts, const string &filename, int N)
 {
        std::vector<mutate_t> database;
+       xs128_t rng(opts.seed);
 
        for (auto module : design->selected_modules())
        {
@@ -260,20 +322,40 @@ void mutate_list(Design *design, const mutate_opts_t &opts, const string &filena
                        {
                                for (int i = 0; i < GetSize(conn.second); i++) {
                                        mutate_t entry;
-                                       entry.mode = "inv";
-                                       entry.src = cell->get_src_attribute();
-                                       entry.module = module;
-                                       entry.cell = cell;
-                                       entry.cellport = conn.first;
+                                       entry.module = module->name;
+                                       entry.cell = cell->name;
+                                       entry.port = conn.first;
                                        entry.portbit = i;
 
-                                       if (cell->output(conn.first)) {
-                                               SigBit bit = sigmap(conn.second[i]);
-                                               if (bit.wire && bit.wire->name[0] == '\\')
-                                                       entry.outsigbit = bit;
+                                       for (auto &s : cell->get_strpool_attribute("\\src"))
+                                               entry.src.insert(s);
+
+                                       SigBit bit = sigmap(conn.second[i]);
+                                       if (bit.wire && bit.wire->name[0] == '\\') {
+                                               for (auto &s : bit.wire->get_strpool_attribute("\\src"))
+                                                       entry.src.insert(s);
+                                               entry.wire = bit.wire->name;
+                                               entry.wirebit = bit.offset;
                                        }
 
+                                       entry.mode = "inv";
+                                       database_add(database, opts, entry);
+
+                                       entry.mode = "const0";
                                        database_add(database, opts, entry);
+
+                                       entry.mode = "const1";
+                                       database_add(database, opts, entry);
+
+                                       entry.mode = "cnot0";
+                                       entry.ctrlbit = rng(GetSize(conn.second));
+                                       if (entry.ctrlbit != entry.portbit && conn.second[entry.ctrlbit].wire)
+                                               database_add(database, opts, entry);
+
+                                       entry.mode = "cnot1";
+                                       entry.ctrlbit = rng(GetSize(conn.second));
+                                       if (entry.ctrlbit != entry.portbit && conn.second[entry.ctrlbit].wire)
+                                               database_add(database, opts, entry);
                                }
                        }
                }
@@ -281,7 +363,7 @@ void mutate_list(Design *design, const mutate_opts_t &opts, const string &filena
 
        log("Raw database size: %d\n", GetSize(database));
        if (N != 0) {
-               database_reduce(database, opts, N);
+               database_reduce(database, opts, N, rng);
                log("Reduced database size: %d\n", GetSize(database));
        }
 
@@ -300,21 +382,22 @@ void mutate_list(Design *design, const mutate_opts_t &opts, const string &filena
                if (!opts.ctrl_name.empty())
                        str += stringf(" -ctrl %s %d %d", log_id(opts.ctrl_name), opts.ctrl_width, ctrl_value++);
                str += stringf(" -mode %s", entry.mode.c_str());
-               if (entry.module)
+               if (!entry.module.empty())
                        str += stringf(" -module %s", log_id(entry.module));
-               if (entry.cell)
+               if (!entry.cell.empty())
                        str += stringf(" -cell %s", log_id(entry.cell));
-               if (!entry.cellport.empty())
-                       str += stringf(" -port %s", log_id(entry.cellport));
+               if (!entry.port.empty())
+                       str += stringf(" -port %s", log_id(entry.port));
                if (entry.portbit >= 0)
-                       str += stringf(" -bit %d", entry.portbit);
-               if (entry.outsigbit.wire || !entry.src.empty()) {
-                       str += " #";
-                       if (!entry.src.empty())
-                               str += stringf(" %s", entry.src.c_str());
-                       if (entry.outsigbit.wire)
-                               str += stringf(" %s", log_signal(entry.outsigbit));
-               }
+                       str += stringf(" -portbit %d", entry.portbit);
+               if (entry.ctrlbit >= 0)
+                       str += stringf(" -ctrlbit %d", entry.ctrlbit);
+               if (!entry.wire.empty())
+                       str += stringf(" -wire %s", log_id(entry.wire));
+               if (entry.wirebit >= 0)
+                       str += stringf(" -wirebit %d", entry.wirebit);
+               for (auto &s : entry.src)
+                       str += stringf(" -src %s", s.c_str());
                if (filename.empty())
                        log("%s\n", str.c_str());
                else
@@ -375,18 +458,18 @@ void mutate_inv(Design *design, const mutate_opts_t &opts)
        Module *module = design->module(opts.module);
        Cell *cell = module->cell(opts.cell);
 
-       SigBit bit = cell->getPort(opts.port)[opts.bit];
+       SigBit bit = cell->getPort(opts.port)[opts.portbit];
        SigBit inbit, outbit;
 
        if (cell->input(opts.port))
        {
-               log("Add input inverter at %s.%s.%s[%d].\n", log_id(module), log_id(cell), log_id(opts.port), opts.bit);
+               log("Add input inverter at %s.%s.%s[%d].\n", log_id(module), log_id(cell), log_id(opts.port), opts.portbit);
                SigBit outbit = module->Not(NEW_ID, bit);
                bit = mutate_ctrl_mux(module, opts, bit, outbit);
        }
        else
        {
-               log("Add output inverter at %s.%s.%s[%d].\n", log_id(module), log_id(cell), log_id(opts.port), opts.bit);
+               log("Add output inverter at %s.%s.%s[%d].\n", log_id(module), log_id(cell), log_id(opts.port), opts.portbit);
                SigBit inbit = module->addWire(NEW_ID);
                SigBit outbit = module->Not(NEW_ID, inbit);
                module->connect(bit, mutate_ctrl_mux(module, opts, inbit, outbit));
@@ -394,7 +477,64 @@ void mutate_inv(Design *design, const mutate_opts_t &opts)
        }
 
        SigSpec s = cell->getPort(opts.port);
-       s[opts.bit] = bit;
+       s[opts.portbit] = bit;
+       cell->setPort(opts.port, s);
+}
+
+void mutate_const(Design *design, const mutate_opts_t &opts, bool one)
+{
+       Module *module = design->module(opts.module);
+       Cell *cell = module->cell(opts.cell);
+
+       SigBit bit = cell->getPort(opts.port)[opts.portbit];
+       SigBit inbit, outbit;
+
+       if (cell->input(opts.port))
+       {
+               log("Add input constant %d at %s.%s.%s[%d].\n", one ? 1 : 0, log_id(module), log_id(cell), log_id(opts.port), opts.portbit);
+               SigBit outbit = one ? State::S1 : State::S0;
+               bit = mutate_ctrl_mux(module, opts, bit, outbit);
+       }
+       else
+       {
+               log("Add output constant %d at %s.%s.%s[%d].\n", one ? 1 : 0, log_id(module), log_id(cell), log_id(opts.port), opts.portbit);
+               SigBit inbit = module->addWire(NEW_ID);
+               SigBit outbit = one ? State::S1 : State::S0;
+               module->connect(bit, mutate_ctrl_mux(module, opts, inbit, outbit));
+               bit = inbit;
+       }
+
+       SigSpec s = cell->getPort(opts.port);
+       s[opts.portbit] = bit;
+       cell->setPort(opts.port, s);
+}
+
+void mutate_cnot(Design *design, const mutate_opts_t &opts, bool one)
+{
+       Module *module = design->module(opts.module);
+       Cell *cell = module->cell(opts.cell);
+
+       SigBit bit = cell->getPort(opts.port)[opts.portbit];
+       SigBit ctrl = cell->getPort(opts.port)[opts.ctrlbit];
+       SigBit inbit, outbit;
+
+       if (cell->input(opts.port))
+       {
+               log("Add input cnot%d at %s.%s.%s[%d,%d].\n", one ? 1 : 0, log_id(module), log_id(cell), log_id(opts.port), opts.portbit, opts.ctrlbit);
+               SigBit outbit = one ? module->Xor(NEW_ID, bit, ctrl) : module->Xnor(NEW_ID, bit, ctrl);
+               bit = mutate_ctrl_mux(module, opts, bit, outbit);
+       }
+       else
+       {
+               log("Add output cnot%d at %s.%s.%s[%d,%d].\n", one ? 1 : 0, log_id(module), log_id(cell), log_id(opts.port), opts.portbit, opts.ctrlbit);
+               SigBit inbit = module->addWire(NEW_ID);
+               SigBit outbit = one ? module->Xor(NEW_ID, inbit, ctrl) : module->Xnor(NEW_ID, inbit, ctrl);
+               module->connect(bit, mutate_ctrl_mux(module, opts, inbit, outbit));
+               bit = inbit;
+       }
+
+       SigSpec s = cell->getPort(opts.port);
+       s[opts.portbit] = bit;
        cell->setPort(opts.port, s);
 }
 
@@ -422,7 +562,11 @@ struct MutatePass : public Pass {
                log("    -module name\n");
                log("    -cell name\n");
                log("    -port name\n");
-               log("    -bit int\n");
+               log("    -portbit int\n");
+               log("    -ctrlbit int\n");
+               log("    -wire name\n");
+               log("    -wirebit int\n");
+               log("    -src string\n");
                log("        Filter list of mutation candidates to those matching\n");
                log("        the given parameters.\n");
                log("\n");
@@ -438,9 +582,15 @@ struct MutatePass : public Pass {
                log("    -module name\n");
                log("    -cell name\n");
                log("    -port name\n");
-               log("    -bit int\n");
+               log("    -portbit int\n");
+               log("    -ctrlbit int\n");
                log("        Mutation parameters, as generated by 'mutate -list N'.\n");
                log("\n");
+               log("    -wire name\n");
+               log("    -wirebit int\n");
+               log("    -src string\n");
+               log("        Ignored. (They are generated by -list for documentation purposes.)\n");
+               log("\n");
        }
        void execute(std::vector<std::string> args, RTLIL::Design *design) YS_OVERRIDE
        {
@@ -487,8 +637,24 @@ struct MutatePass : public Pass {
                                opts.port = RTLIL::escape_id(args[++argidx]);
                                continue;
                        }
-                       if (args[argidx] == "-bit" && argidx+1 < args.size()) {
-                               opts.bit = atoi(args[++argidx].c_str());
+                       if (args[argidx] == "-portbit" && argidx+1 < args.size()) {
+                               opts.portbit = atoi(args[++argidx].c_str());
+                               continue;
+                       }
+                       if (args[argidx] == "-ctrlbit" && argidx+1 < args.size()) {
+                               opts.ctrlbit = atoi(args[++argidx].c_str());
+                               continue;
+                       }
+                       if (args[argidx] == "-wire" && argidx+1 < args.size()) {
+                               opts.wire = RTLIL::escape_id(args[++argidx]);
+                               continue;
+                       }
+                       if (args[argidx] == "-wirebit" && argidx+1 < args.size()) {
+                               opts.wirebit = atoi(args[++argidx].c_str());
+                               continue;
+                       }
+                       if (args[argidx] == "-src" && argidx+1 < args.size()) {
+                               opts.src.insert(args[++argidx]);
                                continue;
                        }
                        break;
@@ -505,6 +671,16 @@ struct MutatePass : public Pass {
                        return;
                }
 
+               if (opts.mode == "const0" || opts.mode == "const1") {
+                       mutate_const(design, opts, opts.mode == "const1");
+                       return;
+               }
+
+               if (opts.mode == "cnot0" || opts.mode == "cnot1") {
+                       mutate_cnot(design, opts, opts.mode == "cnot1");
+                       return;
+               }
+
                log_cmd_error("Invalid mode: %s\n", opts.mode.c_str());
        }
 } MutatePass;