Support various binary operators in opt_share
authorBogdan Vukobratovic <bogdan.vukobratovic@gmail.com>
Sun, 4 Aug 2019 17:06:38 +0000 (19:06 +0200)
committerBogdan Vukobratovic <bogdan.vukobratovic@gmail.com>
Sun, 4 Aug 2019 17:06:38 +0000 (19:06 +0200)
Makefile
passes/opt/opt_share.cc
tests/opt_share/.gitignore [new file with mode: 0644]
tests/opt_share/generate.py [new file with mode: 0644]
tests/opt_share/run-test.sh [new file with mode: 0755]

index 3bc1198001b1b15f2a420f834c24cb1a6c40864e..d06c7ab3d6700b2758c5c37ca9c5925e4ecd914a 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -678,6 +678,7 @@ test: $(TARGETS) $(EXTRA_TARGETS)
        +cd tests/asicworld && bash run-test.sh $(SEEDOPT)
        # +cd tests/realmath && bash run-test.sh $(SEEDOPT)
        +cd tests/share && bash run-test.sh $(SEEDOPT)
+       +cd tests/opt_share && bash run-test.sh $(SEEDOPT)
        +cd tests/fsm && bash run-test.sh $(SEEDOPT)
        +cd tests/techmap && bash run-test.sh
        +cd tests/memories && bash run-test.sh $(ABCOPT) $(SEEDOPT)
index 25b07cbbd2b9a1e875e6a14192d6b42e6fa9b8bd..e8f44749a8aaa18d1c113282c9849abd7a960d73 100644 (file)
@@ -32,37 +32,36 @@ PRIVATE_NAMESPACE_BEGIN
 
 SigMap assign_map;
 
-struct InPort {
+struct OpMuxConn {
        RTLIL::SigSpec sig;
-       RTLIL::Cell *pmux;
-       int port_id;
-       RTLIL::Cell *alu;
+       RTLIL::Cell *mux;
+       RTLIL::Cell *op;
+       int mux_port_id;
+       int mux_port_offset;
+       int op_outsig_offset;
+
+       bool operator<(const OpMuxConn &other) const
+       {
+               if (mux != other.mux)
+                       return mux < other.mux;
+
+               if (mux_port_id != other.mux_port_id)
+                       return mux_port_id < other.mux_port_id;
 
-       InPort(RTLIL::SigSpec s, RTLIL::Cell *c, int p, RTLIL::Cell *a = NULL) : sig(s), pmux(c), port_id(p), alu(a) {}
+               return mux_port_offset < other.mux_port_offset;
+       }
 };
 
-// Helper class that to track whether a SigSpec is signed and whether it is
-// connected to the \\B port of the $sub cell, which makes its sign prefix
-// negative.
+// Helper class to track additiona information about a SigSpec, like whether it is signed and the semantics of the port it is connected to
 struct ExtSigSpec {
        RTLIL::SigSpec sig;
        RTLIL::SigSpec sign;
        bool is_signed;
+       RTLIL::IdString semantics;
 
        ExtSigSpec() {}
 
-       ExtSigSpec(RTLIL::SigSpec s, bool sign = false, bool is_signed = false) : sig(s), sign(sign), is_signed(is_signed) {}
-
-       ExtSigSpec(RTLIL::Cell *cell, RTLIL::IdString port_name, SigMap *sigmap)
-       {
-               sign = (port_name == "\\B") ? cell->getPort("\\BI") : RTLIL::Const(0, 1);
-               sig = (*sigmap)(cell->getPort(port_name));
-
-               is_signed = false;
-               if (cell->hasParam(port_name.str() + "_SIGNED")) {
-                       is_signed = cell->getParam(port_name.str() + "_SIGNED").as_bool();
-               }
-       }
+       ExtSigSpec(RTLIL::SigSpec s, RTLIL::SigSpec sign = RTLIL::Const(0, 1), bool is_signed = false, RTLIL::IdString semantics = RTLIL::IdString()) : sig(s), sign(sign), is_signed(is_signed), semantics(semantics) {}
 
        bool empty() const { return sig.empty(); }
 
@@ -74,42 +73,136 @@ struct ExtSigSpec {
                if (sign != other.sign)
                        return sign < other.sign;
 
-               return is_signed < other.is_signed;
+               if (is_signed != other.is_signed)
+                       return is_signed < other.is_signed;
+
+               return semantics < other.semantics;
        }
 
        bool operator==(const RTLIL::SigSpec &other) const { return (sign != RTLIL::Const(0, 1)) ? false : sig == other; }
-       bool operator==(const ExtSigSpec &other) const { return is_signed == other.is_signed && sign == other.sign && sig == other.sig; }
+       bool operator==(const ExtSigSpec &other) const { return is_signed == other.is_signed && sign == other.sign && sig == other.sig && semantics == other.semantics; }
 };
 
-void merge_operators(RTLIL::Module *module, RTLIL::Cell *mux, const std::vector<InPort> &ports, int offset, int width,
-                    const ExtSigSpec &operand)
+#define BITWISE_OPS "$_AND_", "$_NAND_", "$_OR_", "$_NOR_", "$_XOR_", "$_XNOR_", "$_ANDNOT_", "$_ORNOT_", "$and", "$or", "$xor", "$xnor"
+
+#define REDUCTION_OPS "$reduce_and", "$reduce_or", "$reduce_xor", "$reduce_xnor", "$reduce_bool", "$reduce_nand"
+
+#define LOGICAL_OPS "$logic_and", "$logic_or"
+
+#define SHIFT_OPS "$shl", "$shr", "$sshl", "$sshr", "$shift", "$shiftx"
+
+#define RELATIONAL_OPS "$lt", "$le", "$eq", "$ne", "$eqx", "$nex", "$ge", "$gt"
+
+bool cell_supported(RTLIL::Cell *cell)
+{
+
+       if (cell->type.in("$alu")) {
+               RTLIL::SigSpec sig_bi = cell->getPort("\\BI");
+               RTLIL::SigSpec sig_ci = cell->getPort("\\CI");
+
+               if (sig_bi.is_fully_const() && sig_ci.is_fully_const() && sig_bi == sig_ci)
+                       return true;
+       } else if (cell->type.in(LOGICAL_OPS, SHIFT_OPS, BITWISE_OPS, RELATIONAL_OPS, "$add", "$sub", "$mul", "$div", "$mod", "$concat")) {
+               return true;
+       }
+
+       return false;
+}
+
+std::map<std::string, std::string> mergeable_type_map{
+  {"$sub", "$add"},
+};
+
+bool mergeable(RTLIL::Cell *a, RTLIL::Cell *b)
+{
+       auto a_type = a->type;
+       if (mergeable_type_map.count(a_type.str()))
+               a_type = mergeable_type_map.at(a_type.str());
+
+       auto b_type = b->type;
+       if (mergeable_type_map.count(b_type.str()))
+               b_type = mergeable_type_map.at(b_type.str());
+
+       return a_type == b_type;
+}
+
+RTLIL::IdString decode_port_semantics(RTLIL::Cell *cell, RTLIL::IdString port_name)
+{
+       if (cell->type.in("$lt", "$le", "$ge", "$gt", "$div", "$mod", "$concat", SHIFT_OPS) && port_name == "\\B")
+               return port_name;
+
+       return "";
+}
+
+RTLIL::SigSpec decode_port_sign(RTLIL::Cell *cell, RTLIL::IdString port_name) {
+
+       if (cell->type == "$alu" && port_name == "\\B")
+               return cell->getPort("\\BI");
+       else if (cell->type == "$sub" && port_name == "\\B")
+               return RTLIL::Const(1, 1);
+
+       return RTLIL::Const(0, 1);
+}
+
+bool decode_port_signed(RTLIL::Cell *cell, RTLIL::IdString port_name)
+{
+       if (cell->type.in(BITWISE_OPS, LOGICAL_OPS))
+               return false;
+
+       if (cell->hasParam(port_name.str() + "_SIGNED"))
+               return cell->getParam(port_name.str() + "_SIGNED").as_bool();
+
+       return false;
+
+}
+
+ExtSigSpec decode_port(RTLIL::Cell *cell, RTLIL::IdString port_name, SigMap *sigmap)
+{
+       auto sig = (*sigmap)(cell->getPort(port_name));
+
+       RTLIL::SigSpec sign = decode_port_sign(cell, port_name);
+       RTLIL::IdString semantics = decode_port_semantics(cell, port_name);
+
+       bool is_signed = decode_port_signed(cell, port_name);
+
+       return ExtSigSpec(sig, sign, is_signed, semantics);
+}
+
+void merge_operators(RTLIL::Module *module, RTLIL::Cell *mux, const std::vector<OpMuxConn> &ports, const ExtSigSpec &operand)
 {
 
        std::vector<ExtSigSpec> muxed_operands;
        int max_width = 0;
        for (const auto& p : ports) {
-               auto op = p.alu;
+               auto op = p.op;
 
-               for (RTLIL::IdString port_name : {"\\A", "\\B"}) {
-                       if (op->getPort(port_name) != operand.sig) {
-                               auto operand = ExtSigSpec(op, port_name, &assign_map);
-                               if (operand.sig.size() > max_width) {
-                                       max_width = operand.sig.size();
-                               }
+               RTLIL::IdString muxed_port_name = "\\A";
+               if (op->getPort("\\A") == operand.sig) {
+                       muxed_port_name = "\\B";
+               }
 
-                               muxed_operands.push_back(operand);
-                       }
+               auto operand = decode_port(op, muxed_port_name, &assign_map);
+               if (operand.sig.size() > max_width) {
+                       max_width = operand.sig.size();
                }
+
+               muxed_operands.push_back(operand);
        }
 
+       auto shared_op = ports[0].op;
+
+       if (std::any_of(muxed_operands.begin(), muxed_operands.end(), [&](ExtSigSpec &op) { return op.sign != muxed_operands[0].sign; }))
+               if (max_width < shared_op->getParam("\\Y_WIDTH").as_int())
+                       max_width = shared_op->getParam("\\Y_WIDTH").as_int();
+
+
        for (auto &operand : muxed_operands) {
                operand.sig.extend_u0(max_width, operand.is_signed);
        }
 
-       auto shared_op = ports[0].alu;
 
        for (const auto& p : ports) {
-               auto op = p.alu;
+               auto op = p.op;
                if (op == shared_op)
                        continue;
                module->remove(op);
@@ -126,40 +219,47 @@ void merge_operators(RTLIL::Module *module, RTLIL::Cell *mux, const std::vector<
        RTLIL::SigSpec mux_b = mux->getPort("\\B");
        RTLIL::SigSpec mux_s = mux->getPort("\\S");
 
-       RTLIL::SigSpec alu_x = shared_op->getPort("\\X");
-       RTLIL::SigSpec alu_co = shared_op->getPort("\\CO");
-
        RTLIL::SigSpec shared_pmux_a = RTLIL::Const(RTLIL::State::Sx, max_width);
        RTLIL::SigSpec shared_pmux_b;
        RTLIL::SigSpec shared_pmux_s;
 
-       shared_op->setPort("\\Y", shared_op->getPort("\\Y").extract(0, width));
+       int conn_width = ports[0].sig.size();
+       int conn_offset = ports[0].mux_port_offset;
+
+       shared_op->setPort("\\Y", shared_op->getPort("\\Y").extract(0, conn_width));
 
        if (mux->type == "$pmux") {
                shared_pmux_s = RTLIL::SigSpec();
 
-               for (const auto&p: ports) {
-                       shared_pmux_s.append(mux_s[p.port_id]);
-                       mux_b.replace(p.port_id * mux_a.size() + offset, shared_op->getPort("\\Y"));
+               for (const auto &p : ports) {
+                       shared_pmux_s.append(mux_s[p.mux_port_id]);
+                       mux_b.replace(p.mux_port_id * mux_a.size() + conn_offset, shared_op->getPort("\\Y"));
                }
        } else {
                shared_pmux_s = RTLIL::SigSpec{mux_s, module->Not(NEW_ID, mux_s)};
-               mux_a.replace(offset, shared_op->getPort("\\Y"));
-               mux_b.replace(offset, shared_op->getPort("\\Y"));
+               mux_a.replace(conn_offset, shared_op->getPort("\\Y"));
+               mux_b.replace(conn_offset, shared_op->getPort("\\Y"));
        }
 
+       mux->setPort("\\A", mux_a);
+       mux->setPort("\\B", mux_b);
        mux->setPort("\\Y", mux_y);
        mux->setPort("\\S", mux_s);
-       mux->setPort("\\B", mux_b);
 
        for (const auto &op : muxed_operands)
                shared_pmux_b.append(op.sig);
 
        auto mux_to_oper = module->Pmux(NEW_ID, shared_pmux_a, shared_pmux_b, shared_pmux_s);
 
-       shared_op->setPort("\\X", alu_x.extract(0, width));
-       shared_op->setPort("\\CO", alu_co.extract(0, width));
-       shared_op->setParam("\\Y_WIDTH", width);
+       if (shared_op->type.in("$alu")) {
+               RTLIL::SigSpec alu_x = shared_op->getPort("\\X");
+               RTLIL::SigSpec alu_co = shared_op->getPort("\\CO");
+
+               shared_op->setPort("\\X", alu_x.extract(0, conn_width));
+               shared_op->setPort("\\CO", alu_co.extract(0, conn_width));
+       }
+
+       shared_op->setParam("\\Y_WIDTH", conn_width);
 
        if (shared_op->getPort("\\A") == operand.sig) {
                shared_op->setPort("\\B", mux_to_oper);
@@ -173,11 +273,9 @@ void merge_operators(RTLIL::Module *module, RTLIL::Cell *mux, const std::vector<
 
 typedef struct {
        RTLIL::Cell *mux;
-       std::vector<InPort> ports;
-       int offset;
-       int width;
+       std::vector<OpMuxConn> ports;
        ExtSigSpec shared_operand;
-} shared_op_t;
+} merged_op_t;
 
 
 template <typename T> void remove_val(std::vector<T> &v, const std::vector<T> &vals)
@@ -190,86 +288,60 @@ template <typename T> void remove_val(std::vector<T> &v, const std::vector<T> &v
                }
 }
 
-bool find_op_res_width(int offset, int &width, std::vector<InPort*>& ports, const dict<RTLIL::SigBit, RTLIL::SigSpec> &op_outbit_to_outsig)
+void check_muxed_operands(std::vector<const OpMuxConn *> &ports, const ExtSigSpec &shared_operand)
 {
 
-       std::vector<RTLIL::SigSpec> op_outsigs;
-       dict<int, std::set<InPort*>> op_outsig_span;
-
-       std::transform(ports.begin(), ports.end(), std::back_inserter(op_outsigs), [&](InPort *p) { return op_outbit_to_outsig.at(p->sig[offset]); });
-
-       std::vector<bool> finished(ports.size(), false);
+       auto it = ports.begin();
+       ExtSigSpec seed;
 
-       width = 0;
+       while (it != ports.end()) {
+               auto p = *it;
+               auto op = p->op;
 
-       std::function<bool()> all_finished = [&] { return std::find(std::begin(finished), std::end(finished), false) == end(finished);};
-
-       while (!all_finished())
-       {
-               ++offset;
-               ++width;
-
-               if (offset >= ports[0]->sig.size()) {
-                       for (size_t i = 0; i < op_outsigs.size(); ++i) {
-                               if (finished[i])
-                                       continue;
-
-                               op_outsig_span[width].insert(ports[i]);
-                               finished[i] = true;
-                       }
-
-                       break;
+               RTLIL::IdString muxed_port_name = "\\A";
+               if (op->getPort("\\A") == shared_operand.sig) {
+                       muxed_port_name = "\\B";
                }
 
-               for (size_t i = 0; i < op_outsigs.size(); ++i) {
-                       if (finished[i])
-                               continue;
+               auto operand = decode_port(op, muxed_port_name, &assign_map);
 
-                       if ((width >= op_outsigs[i].size()) || (ports[i]->sig[offset] != op_outsigs[i][width])) {
-                               op_outsig_span[width].insert(ports[i]);
-                               finished[i] = true;
-                       }
-               }
-       }
-
-       for (auto w: op_outsig_span) {
-               if (w.second.size() > 1) {
-                       width = w.first;
+               if (seed.empty())
+                       seed = operand;
 
-                       ports.erase(std::remove_if(ports.begin(), ports.end(), [&](InPort *p) { return !w.second.count(p); }), ports.end());
-
-                       return true;
+               if (operand.is_signed != seed.is_signed) {
+                       ports.erase(it);
+               } else {
+                       ++it;
                }
        }
-
-       return false;
 }
 
-ExtSigSpec find_shared_operand(InPort* seed, std::vector<InPort *> &ports, const std::map<ExtSigSpec, std::set<RTLIL::Cell *>> &operand_to_users)
+ExtSigSpec find_shared_operand(const OpMuxConn* seed, std::vector<const OpMuxConn *> &ports, const std::map<ExtSigSpec, std::set<RTLIL::Cell *>> &operand_to_users)
 {
-       std::set<RTLIL::Cell *> alus_using_operand;
-       std::set<RTLIL::Cell *> alus_set;
+       std::set<RTLIL::Cell *> ops_using_operand;
+       std::set<RTLIL::Cell *> ops_set;
        for(const auto& p: ports)
-               alus_set.insert(p->alu);
+               ops_set.insert(p->op);
 
        ExtSigSpec oper;
 
-       auto op_a = seed->alu;
+       auto op_a = seed->op;
 
        for (RTLIL::IdString port_name : {"\\A", "\\B"}) {
-               oper = ExtSigSpec(op_a, port_name, &assign_map);
+               oper = decode_port(op_a, port_name, &assign_map);
                auto operand_users = operand_to_users.at(oper);
 
                if (operand_users.size() == 1)
                        continue;
 
-               alus_using_operand.clear();
-               std::set_intersection(operand_users.begin(), operand_users.end(), alus_set.begin(), alus_set.end(),
-                                     std::inserter(alus_using_operand, alus_using_operand.begin()));
+               ops_using_operand.clear();
+               for (auto mux_ops: ops_set)
+                       if (operand_users.count(mux_ops))
+                               ops_using_operand.insert(mux_ops);
 
-               if (alus_using_operand.size() > 1) {
-                       ports.erase(std::remove_if(ports.begin(), ports.end(), [&](InPort *p) { return !alus_using_operand.count(p->alu); }),
-                                   ports.end());
+               if (ops_using_operand.size() > 1) {
+                       ports.erase(std::remove_if(ports.begin(), ports.end(), [&](const OpMuxConn *p) { return !ops_using_operand.count(p->op); }),
+                                               ports.end());
                        return oper;
                }
        }
@@ -277,40 +349,135 @@ ExtSigSpec find_shared_operand(InPort* seed, std::vector<InPort *> &ports, const
        return ExtSigSpec();
 }
 
-void remove_multi_user_outbits(RTLIL::Module *module, dict<RTLIL::SigBit, RTLIL::SigSpec> &op_outbit_to_outsig)
+dict<RTLIL::SigSpec, OpMuxConn> find_valid_op_mux_conns(RTLIL::Module *module, dict<RTLIL::SigBit, RTLIL::SigSpec> &op_outbit_to_outsig,
+                                                       dict<RTLIL::SigSpec, RTLIL::Cell *> outsig_to_operator,
+                                                       dict<RTLIL::SigBit, RTLIL::SigSpec> &op_aux_to_outsig)
 {
-       dict<RTLIL::SigBit, int> op_outbit_user_cnt;
+       dict<RTLIL::SigSpec, int> op_outsig_user_track;
+       dict<RTLIL::SigSpec, OpMuxConn> op_mux_conn_map;
 
-       std::function<void(SigSpec)> update_op_outbit_user_cnt = [&](SigSpec sig) {
-               auto outsig = assign_map(sig);
-               for (auto outbit : outsig) {
-                       if (!op_outbit_to_outsig.count(outbit))
-                               continue;
+       std::function<void(RTLIL::SigSpec)> remove_outsig = [&](RTLIL::SigSpec outsig) {
+               for (auto op_outbit : outsig)
+                       op_outbit_to_outsig.erase(op_outbit);
+
+               if (op_mux_conn_map.count(outsig))
+                       op_mux_conn_map.erase(outsig);
+       };
 
-                       if (++op_outbit_user_cnt[outbit] > 1) {
-                               auto alu_outsig = op_outbit_to_outsig.at(outbit);
+       std::function<void(RTLIL::SigBit)> remove_outsig_from_aux_bit = [&](RTLIL::SigBit auxbit) {
+               auto aux_outsig = op_aux_to_outsig.at(auxbit);
+               auto op = outsig_to_operator.at(aux_outsig);
+               auto op_outsig = assign_map(op->getPort("\\Y"));
+               remove_outsig(op_outsig);
 
-                               for (auto outbit : alu_outsig)
-                                       op_outbit_to_outsig.erase(outbit);
+               for (auto aux_outbit : aux_outsig)
+                       op_aux_to_outsig.erase(aux_outbit);
+       };
+
+       std::function<void(RTLIL::Cell *)>
+         find_op_mux_conns = [&](RTLIL::Cell *mux) {
+                 RTLIL::SigSpec sig;
+                 int mux_port_size;
+
+                 if (mux->type.in("$mux", "$_MUX_")) {
+                         mux_port_size = mux->getPort("\\A").size();
+                         sig = RTLIL::SigSpec{mux->getPort("\\B"), mux->getPort("\\A")};
+                 } else {
+                         mux_port_size = mux->getPort("\\A").size();
+                         sig = mux->getPort("\\B");
+                 }
+
+                 auto mux_insig = assign_map(sig);
+
+                 for (int i = 0; i < mux_insig.size(); ++i) {
+                         if (op_aux_to_outsig.count(mux_insig[i])) {
+                                 remove_outsig_from_aux_bit(mux_insig[i]);
+                                 continue;
+                         }
+
+                         if (!op_outbit_to_outsig.count(mux_insig[i]))
+                                 continue;
+
+                         auto op_outsig = op_outbit_to_outsig.at(mux_insig[i]);
+
+                         if (op_mux_conn_map.count(op_outsig)) {
+                                       remove_outsig(op_outsig);
+                                 continue;
+                         }
+
+                         int mux_port_id = i / mux_port_size;
+                         int mux_port_offset = i % mux_port_size;
+
+                         int op_outsig_offset;
+                         for (op_outsig_offset = 0; op_outsig[op_outsig_offset] != mux_insig[i]; ++op_outsig_offset)
+                                 ;
+
+                         int j = op_outsig_offset;
+                         do {
+                                 if (!op_outbit_to_outsig.count(mux_insig[i]))
+                                         break;
+
+                                 if (op_outbit_to_outsig.at(mux_insig[i]) != op_outsig)
+                                         break;
+
+                                 ++i;
+                                 ++j;
+                         } while ((i / mux_port_size == mux_port_id) && (j < op_outsig.size()));
+
+                         int op_conn_width = j - op_outsig_offset;
+                         OpMuxConn inp = {
+                           op_outsig.extract(op_outsig_offset, op_conn_width),
+                           mux,
+                           outsig_to_operator.at(op_outsig),
+                           mux_port_id,
+                           mux_port_offset,
+                           op_outsig_offset,
+                         };
+
+                         op_mux_conn_map[op_outsig] = inp;
+
+                         --i;
+                 }
+         };
+
+       std::function<void(RTLIL::SigSpec)> remove_connected_ops = [&](RTLIL::SigSpec sig) {
+               auto mux_insig = assign_map(sig);
+               for (auto outbit : mux_insig) {
+                       if (op_aux_to_outsig.count(outbit)) {
+                               remove_outsig_from_aux_bit(outbit);
+                               continue;
                        }
+
+                       if (!op_outbit_to_outsig.count(outbit))
+                               continue;
+
+                       remove_outsig(op_outbit_to_outsig.at(outbit));
                }
        };
 
-       for (auto cell : module->cells())
-               for (auto &conn : cell->connections())
-                       if (cell->input(conn.first))
-                               update_op_outbit_user_cnt(conn.second);
+       for (auto cell : module->cells()) {
+               if (cell->type.in("$mux", "$_MUX_", "$pmux")) {
+                       remove_connected_ops(cell->getPort("\\S"));
+                       find_op_mux_conns(cell);
+               } else {
+                       for (auto &conn : cell->connections())
+                               if (cell->input(conn.first))
+                                       remove_connected_ops(conn.second);
+               }
+       }
 
        for (auto w : module->wires()) {
                if (!w->port_output)
                        continue;
 
-               update_op_outbit_user_cnt(w);
+               remove_connected_ops(w);
        }
+
+       return op_mux_conn_map;
 }
 
 struct OptSharePass : public Pass {
-       OptSharePass() : Pass("opt_share", "merge arithmetic operators that share an operand") {}
+       OptSharePass() : Pass("opt_share", "merge mutually exclusive cells of the same type that share an input signal") {}
        void help() YS_OVERRIDE
        {
                //   |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|
@@ -318,18 +485,19 @@ struct OptSharePass : public Pass {
                log("    opt_share [selection]\n");
                log("\n");
 
-               log("This pass identifies mutually exclusive $alu arithmetic cells that:\n");
-               log("    (a) share an input operand\n");
+               log("This pass identifies mutually exclusive cells of the same type that:\n");
+               log("    (a) share an input signal\n");
                log("    (b) drive the same $mux, $_MUX_, or $pmux multiplexing cell allowing\n");
-               log("        the $alu cell to be merged and the multiplexer to be moved from\n");
-               log("        multiplexing its output to multiplexing the non-shared input operands.\n");
+               log("        the cell to be merged and the multiplexer to be moved from\n");
+               log("        multiplexing its output to multiplexing the non-shared input signals.\n");
                log("\n");
        }
-       void execute(std::vector<std::string>, RTLIL::Design *design) YS_OVERRIDE
+       void execute(std::vector<std::string> args, RTLIL::Design *design) YS_OVERRIDE
        {
 
                log_header(design, "Executing OPT_SHARE pass.\n");
 
+               extra_args(args, 1, design);
                for (auto module : design->selected_modules()) {
                        assign_map.clear();
                        assign_map.set(module);
@@ -337,28 +505,30 @@ struct OptSharePass : public Pass {
                        std::map<ExtSigSpec, std::set<RTLIL::Cell *>> operand_to_users;
                        dict<RTLIL::SigSpec, RTLIL::Cell *> outsig_to_operator;
                        dict<RTLIL::SigBit, RTLIL::SigSpec> op_outbit_to_outsig;
+                       dict<RTLIL::SigBit, RTLIL::SigSpec> op_aux_to_outsig;
                        bool any_shared_operands = false;
                        std::vector<ExtSigSpec> op_insigs;
 
                        for (auto cell : module->cells()) {
-                               if (!cell->type.in("$alu"))
+                               if (!cell_supported(cell))
                                        continue;
 
-                               RTLIL::SigSpec sig_bi = cell->getPort("\\BI");
-                               RTLIL::SigSpec sig_ci = cell->getPort("\\CI");
-
-                               if ((!sig_bi.is_fully_const()) || (!sig_ci.is_fully_const()) || (sig_bi != sig_ci))
-                                       continue;
-
-                               RTLIL::SigSpec sig_y = cell->getPort("\\A");
+                               if (cell->type == "$alu") {
+                                       for (RTLIL::IdString port_name : {"\\X", "\\CO"}) {
+                                               auto mux_insig = assign_map(cell->getPort(port_name));
+                                               outsig_to_operator[mux_insig] = cell;
+                                               for (auto outbit : mux_insig)
+                                                       op_aux_to_outsig[outbit] = mux_insig;
+                                       }
+                               }
 
-                               auto outsig = assign_map(cell->getPort("\\Y"));
-                               outsig_to_operator[outsig] = cell;
-                               for (auto outbit : outsig)
-                                       op_outbit_to_outsig[outbit] = outsig;
+                               auto mux_insig = assign_map(cell->getPort("\\Y"));
+                               outsig_to_operator[mux_insig] = cell;
+                               for (auto outbit : mux_insig)
+                                       op_outbit_to_outsig[outbit] = mux_insig;
 
                                for (RTLIL::IdString port_name : {"\\A", "\\B"}) {
-                                       auto op_insig = ExtSigSpec(cell, port_name, &assign_map);
+                                       auto op_insig = decode_port(cell, port_name, &assign_map);
                                        op_insigs.push_back(op_insig);
                                        operand_to_users[op_insig].insert(cell);
                                        if (operand_to_users[op_insig].size() > 1)
@@ -371,89 +541,117 @@ struct OptSharePass : public Pass {
 
                        // Operator outputs need to be exclusively connected to the $mux inputs in order to be mergeable. Hence we count to
                        // how many points are operator output bits connected.
-                       remove_multi_user_outbits(module, op_outbit_to_outsig);
+                       dict<RTLIL::SigSpec, OpMuxConn> op_mux_conn_map =
+                         find_valid_op_mux_conns(module, op_outbit_to_outsig, outsig_to_operator, op_aux_to_outsig);
 
-                       std::vector<shared_op_t> shared_ops;
-                       for (auto cell : module->cells()) {
-                               if (!cell->type.in("$mux", "$_MUX_", "$pmux"))
-                                       continue;
+                       // Group op connections connected to same ports of the same $mux. Sort them in ascending order of their port offset
+                       dict<RTLIL::Cell*, std::vector<std::set<OpMuxConn>>> mux_port_op_conns;
+                       for (auto& val: op_mux_conn_map) {
+                               OpMuxConn p = val.second;
+                               auto& mux_port_conns = mux_port_op_conns[p.mux];
 
-                               RTLIL::SigSpec sig_a = cell->getPort("\\A");
-                               RTLIL::SigSpec sig_b = cell->getPort("\\B");
-                               RTLIL::SigSpec sig_s = cell->getPort("\\S");
+                               if (mux_port_conns.size() == 0) {
+                                       int mux_port_num;
 
-                               std::vector<InPort> ports;
+                                       if (p.mux->type.in("$mux", "$_MUX_"))
+                                               mux_port_num = 2;
+                                       else
+                                               mux_port_num = p.mux->getPort("\\S").size();
 
-                               if (cell->type.in("$mux", "$_MUX_")) {
-                                       ports.push_back(InPort(assign_map(sig_a), cell, 0));
-                                       ports.push_back(InPort(assign_map(sig_b), cell, 1));
-                               } else {
-                                       RTLIL::SigSpec sig_s = cell->getPort("\\S");
-                                       for (int i = 0; i < sig_s.size(); i++) {
-                                               auto inp = sig_b.extract(i * sig_a.size(), sig_a.size());
-                                               ports.push_back(InPort(assign_map(inp), cell, i));
-                                       }
+                                       mux_port_conns.resize(mux_port_num);
                                }
 
+                               mux_port_conns[p.mux_port_id].insert(p);
+                       }
+
+                       std::vector<merged_op_t> merged_ops;
+                       for (auto& val: mux_port_op_conns) {
+
+                               RTLIL::Cell* cell = val.first;
+                               auto &mux_port_conns = val.second;
+
+                               const OpMuxConn *seed = NULL;
+
                                // Look through the bits of the $mux inputs and see which of them are connected to the operator
                                // results. Operator results can be concatenated with other signals before led to the $mux.
-                               for (int i = 0; i < sig_a.size(); ++i) {
-                                       std::vector<InPort*> alu_ports;
-                                       for (auto& p: ports)
-                                               if (op_outbit_to_outsig.count(p.sig[i])) {
-                                                       p.alu = outsig_to_operator.at(op_outbit_to_outsig.at(p.sig[i]));
-                                                       alu_ports.push_back(&p);
-                                               }
+                               while (true) {
 
-                                       int alu_port_width = 0;
+                                       // Remove either the merged ports from the last iteration or the seed that failed to yield a merger
+                                       if (seed != NULL) {
+                                               mux_port_conns[seed->mux_port_id].erase(*seed);
+                                               seed = NULL;
+                                       }
 
-                                       while (alu_ports.size() > 1) {
-                                               std::vector<InPort*> shared_ports(alu_ports);
+                                       // For a new merger, find the seed op connection that starts at lowest port offset among port connections
+                                       for (auto &port_conns : mux_port_conns) {
+                                               if (!port_conns.size())
+                                                       continue;
 
-                                               auto seed = alu_ports[0];
-                                               alu_ports.erase(alu_ports.begin());
+                                               const OpMuxConn *next_p = &(*port_conns.begin());
 
-                                               // Find ports whose $alu-s share an operand with $alu connected to the seed port
-                                               auto shared_operand = find_shared_operand(seed, shared_ports, operand_to_users);
+                                               if ((seed == NULL) || (seed->mux_port_offset > next_p->mux_port_offset))
+                                                       seed = next_p;
+                                       }
 
-                                               if (shared_operand.empty())
+                                       // Cannot find the seed -> nothing to do for this $mux anymore
+                                       if (seed == NULL)
+                                               break;
+
+                                       // Find all other op connections that start from the same port offset, and whose ops can be merged with the seed op
+                                       std::vector<const OpMuxConn *> mergeable_conns;
+                                       for (auto &port_conns : mux_port_conns) {
+                                               if (!port_conns.size())
                                                        continue;
 
-                                               // Some bits of the operator results might be unconnected. Calculate the number of conneted
-                                               // bits.
-                                               if (!find_op_res_width(i, alu_port_width, shared_ports, op_outbit_to_outsig))
-                                                       break;
+                                               const OpMuxConn *next_p = &(*port_conns.begin());
+
+                                               if ((next_p->op_outsig_offset == seed->op_outsig_offset) &&
+                                                   (next_p->mux_port_offset == seed->mux_port_offset) && mergeable(next_p->op, seed->op) &&
+                                                   next_p->sig.size() == seed->sig.size())
+                                                       mergeable_conns.push_back(next_p);
+                                       }
+
+                                       // We need at least two mergeable connections for the merger
+                                       if (mergeable_conns.size() < 2)
+                                               continue;
 
-                                               if (shared_ports.size() < 2)
-                                                       break;
+                                       // Filter mergeable connections whose ops share an operand with seed connection's op
+                                       auto shared_operand = find_shared_operand(seed, mergeable_conns, operand_to_users);
 
-                                               // Remember the combination for the merger
-                                               std::vector<InPort> shared_p;
-                                               for (auto p: shared_ports)
-                                                       shared_p.push_back(*p);
+                                       if (shared_operand.empty())
+                                               continue;
 
-                                               shared_ops.push_back(shared_op_t{cell, shared_p, i, alu_port_width, shared_operand});
+                                       check_muxed_operands(mergeable_conns, shared_operand);
 
-                                               // Remove merged ports from the list and try to find other mergers for the mux
-                                               remove_val(alu_ports, shared_ports);
+                                       if (mergeable_conns.size() < 2)
+                                               continue;
+
+                                       // Remember the combination for the merger
+                                       std::vector<OpMuxConn> merged_ports;
+                                       for (auto p : mergeable_conns) {
+                                               merged_ports.push_back(*p);
+                                               mux_port_conns[p->mux_port_id].erase(*p);
                                        }
 
-                                       if (alu_port_width)
-                                               i += alu_port_width - 1;
+                                       seed = NULL;
+
+                                       merged_ops.push_back(merged_op_t{cell, merged_ports, shared_operand});
+
+                                       design->scratchpad_set_bool("opt.did_something", true);
                                }
 
                        }
 
-                       for (auto &shared : shared_ops) {
-                               log("    Found arithmetic cells that share an operand and can be merged by moving the %s %s in front "
+                       for (auto &shared : merged_ops) {
+                               log("    Found cells that share an operand and can be merged by moving the %s %s in front "
                                    "of "
                                    "them:\n",
                                    log_id(shared.mux->type), log_id(shared.mux));
                                for (const auto& op : shared.ports)
-                                       log("        %s\n", log_id(op.alu));
+                                       log("        %s\n", log_id(op.op));
                                log("\n");
 
-                               merge_operators(module, shared.mux, shared.ports, shared.offset, shared.width, shared.shared_operand);
+                               merge_operators(module, shared.mux, shared.ports, shared.shared_operand);
                        }
                }
        }
diff --git a/tests/opt_share/.gitignore b/tests/opt_share/.gitignore
new file mode 100644 (file)
index 0000000..9c595a6
--- /dev/null
@@ -0,0 +1 @@
+temp
diff --git a/tests/opt_share/generate.py b/tests/opt_share/generate.py
new file mode 100644 (file)
index 0000000..2ec92f7
--- /dev/null
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+
+import argparse
+import sys
+import random
+from contextlib import contextmanager
+
+
+@contextmanager
+def redirect_stdout(new_target):
+    old_target, sys.stdout = sys.stdout, new_target
+    try:
+        yield new_target
+    finally:
+        sys.stdout = old_target
+
+
+def random_plus_x():
+    return "%s x" % random.choice(['+', '+', '+', '-', '-', '|', '&', '^'])
+
+
+def maybe_plus_x(expr):
+    if random.randint(0, 4) == 0:
+        return "(%s %s)" % (expr, random_plus_x())
+    else:
+        return expr
+
+
+parser = argparse.ArgumentParser(
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('-S', '--seed', type=int, help='seed for PRNG')
+parser.add_argument('-c',
+                    '--count',
+                    type=int,
+                    default=100,
+                    help='number of test cases to generate')
+args = parser.parse_args()
+
+if args.seed is not None:
+    print("PRNG seed: %d" % args.seed)
+    random.seed(args.seed)
+
+for idx in range(args.count):
+    with open('temp/uut_%05d.v' % idx, 'w') as f:
+        with redirect_stdout(f):
+            print('module uut_%05d(a, b, c, s, y);' % (idx))
+            op = random.choice([
+                random.choice(['+', '-', '*', '/', '%']),
+                random.choice(['<', '<=', '==', '!=', '===', '!==', '>=',
+                               '>']),
+                random.choice(['<<', '>>', '<<<', '>>>']),
+                random.choice(['|', '&', '^', '~^', '||', '&&']),
+            ])
+            print('  input%s [%d:0] a;' % (random.choice(['', ' signed']), 8))
+            print('  input%s [%d:0] b;' % (random.choice(['', ' signed']), 8))
+            print('  input%s [%d:0] c;' % (random.choice(['', ' signed']), 8))
+            print('  input s;')
+            print('  output [%d:0] y;' % 8)
+            ops1 = ['a', 'b']
+            ops2 = ['a', 'c']
+            random.shuffle(ops1)
+            random.shuffle(ops2)
+            cast1 = random.choice(['', '$signed', '$unsigned'])
+            cast2 = random.choice(['', '$signed', '$unsigned'])
+            print('  assign y = (s ? %s(%s %s %s) : %s(%s %s %s));' %
+                  (cast1, ops1[0], op, ops1[1],
+                   cast2, ops2[0], op, ops2[1]))
+            print('endmodule')
+
+    with open('temp/uut_%05d.ys' % idx, 'w') as f:
+        with redirect_stdout(f):
+            print('read_verilog temp/uut_%05d.v' % idx)
+            print('proc;;')
+            print('copy uut_%05d gold' % idx)
+            print('rename uut_%05d gate' % idx)
+            print('tee -a temp/all_share_log.txt log')
+            print('tee -a temp/all_share_log.txt log #job# uut_%05d' % idx)
+            print('tee -a temp/all_share_log.txt opt gate')
+            print('tee -a temp/all_share_log.txt opt_share gate')
+            print('tee -a temp/all_share_log.txt opt_clean gate')
+            print(
+                'miter -equiv -flatten -ignore_gold_x -make_outputs -make_outcmp gold gate miter'
+            )
+            print(
+                'sat -set-def-inputs -verify -prove trigger 0 -show-inputs -show-outputs miter'
+            )
diff --git a/tests/opt_share/run-test.sh b/tests/opt_share/run-test.sh
new file mode 100755 (executable)
index 0000000..e015526
--- /dev/null
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# run this test many times:
+# time bash -c 'for ((i=0; i<100; i++)); do echo "-- $i --"; bash run-test.sh || exit 1; done'
+
+set -e
+
+OPTIND=1
+count=100
+seed=""    # default to no seed specified
+while getopts "c:S:" opt
+do
+  case "$opt" in
+               c) count="$OPTARG" ;;
+               S) seed="-S $OPTARG" ;;
+  esac
+done
+shift "$((OPTIND-1))"
+
+rm -rf temp
+mkdir -p temp
+echo "generating tests.."
+python3 generate.py -c $count $seed
+
+echo "running tests.."
+for i in $( ls temp/*.ys | sed 's,[^0-9],,g; s,^0*\(.\),\1,g;' ); do
+       echo -n "[$i]"
+       idx=$( printf "%05d" $i )
+       ../../yosys -ql temp/uut_${idx}.log temp/uut_${idx}.ys
+done
+echo
+
+failed_share=$( echo $( gawk '/^#job#/ { j=$2; db[j]=0; } /^Removing [246] cells/ { delete db[j]; } END { for (j in db) print(j); }' temp/all_share_log.txt ) )
+if [ -n "$failed_share" ]; then
+       echo "Resource sharing failed for the following test cases: $failed_share"
+       false
+fi
+
+exit 0