#include "kernel/yosys.h"
 #include "kernel/sigtools.h"
+#include "kernel/ffinit.h"
+#include "kernel/ff.h"
 
 USING_YOSYS_NAMESPACE
 PRIVATE_NAMESPACE_BEGIN
                for (auto module : design->selected_modules())
                {
                        SigMap sigmap(module);
-                       dict<SigBit, State> initbits;
-                       pool<SigBit> del_initbits;
-
-                       for (auto wire : module->wires())
-                               if (wire->attributes.count(ID::init) > 0)
-                               {
-                                       Const initval = wire->attributes.at(ID::init);
-                                       SigSpec initsig = sigmap(wire);
-
-                                       for (int i = 0; i < GetSize(initval) && i < GetSize(initsig); i++)
-                                               if (initval[i] == State::S0 || initval[i] == State::S1)
-                                                       initbits[initsig[i]] = initval[i];
-                               }
+                       FfInitVals initvals(&sigmap, module);
 
                        for (auto cell : vector<Cell*>(module->selected_cells()))
                        {
                                        cell->setPort(ID::WR_DATA, wr_data_port);
                                }
 
-                               if (cell->type.in(ID($dlatch), ID($adlatch), ID($dlatchsr)))
-                               {
-                                       bool enpol = cell->parameters[ID::EN_POLARITY].as_bool();
-
-                                       SigSpec sig_en = cell->getPort(ID::EN);
-                                       SigSpec sig_d = cell->getPort(ID::D);
-                                       SigSpec sig_q = cell->getPort(ID::Q);
-
-                                       log("Replacing %s.%s (%s): EN=%s, D=%s, Q=%s\n",
-                                                       log_id(module), log_id(cell), log_id(cell->type),
-                                                       log_signal(sig_en), log_signal(sig_d), log_signal(sig_q));
-
-                                       sig_en = wrap_async_control(module, sig_en, enpol);
-
-                                       Wire *past_q = module->addWire(NEW_ID, GetSize(sig_q));
-                                       module->addFf(NEW_ID, sig_q, past_q);
-
-                                       if (cell->type == ID($dlatch))
-                                       {
-                                               module->addMux(NEW_ID, past_q, sig_d, sig_en, sig_q);
-                                       }
-                                       else if (cell->type == ID($adlatch))
-                                       {
-                                               SigSpec t = module->Mux(NEW_ID, past_q, sig_d, sig_en);
-                                               SigSpec arst = wrap_async_control(module, cell->getPort(ID::ARST), cell->parameters[ID::ARST_POLARITY].as_bool());
-                                               Const rstval = cell->parameters[ID::ARST_VALUE];
-
-                                               module->addMux(NEW_ID, t, rstval, arst, sig_q);
-                                       }
-                                       else
-                                       {
-                                               SigSpec t = module->Mux(NEW_ID, past_q, sig_d, sig_en);
-
-                                               SigSpec s = wrap_async_control(module, cell->getPort(ID::SET), cell->parameters[ID::SET_POLARITY].as_bool());
-                                               t = module->Or(NEW_ID, t, s);
+                               SigSpec qval;
+                               if (RTLIL::builtin_ff_cell_types().count(cell->type)) {
+                                       FfData ff(&initvals, cell);
 
-                                               SigSpec c = wrap_async_control(module, cell->getPort(ID::CLR), cell->parameters[ID::CLR_POLARITY].as_bool());
-                                               c = module->Not(NEW_ID, c);
-                                               module->addAnd(NEW_ID, t, c, sig_q);
+                                       if (ff.has_d && !ff.has_clk && !ff.has_en) {
+                                               // Already a $ff or $_FF_ cell.
+                                               continue;
                                        }
 
-                                       Const initval;
-                                       bool assign_initval = false;
-                                       for (int i = 0; i < GetSize(sig_d); i++) {
-                                               SigBit qbit = sigmap(sig_q[i]);
-                                               if (initbits.count(qbit)) {
-                                                       initval.bits.push_back(initbits.at(qbit));
-                                                       del_initbits.insert(qbit);
-                                               } else
-                                                       initval.bits.push_back(State::Sx);
-                                               if (initval.bits.back() != State::Sx)
-                                                       assign_initval = true;
+                                       Wire *past_q = module->addWire(NEW_ID, ff.width);
+                                       if (!ff.is_fine) {
+                                               module->addFf(NEW_ID, ff.sig_q, past_q);
+                                       } else {
+                                               module->addFfGate(NEW_ID, ff.sig_q, past_q);
                                        }
+                                       if (!ff.val_init.is_fully_undef())
+                                               initvals.set_init(past_q, ff.val_init);
+
+                                       if (ff.has_clk) {
+                                               SigSpec sig_d = ff.sig_d;
+                                               if (ff.has_srst && ff.has_en && ff.ce_over_srst) {
+                                                       if (!ff.is_fine) {
+                                                               if (ff.pol_srst)
+                                                                       sig_d = module->Mux(NEW_ID, sig_d, ff.val_srst, ff.sig_srst);
+                                                               else
+                                                                       sig_d = module->Mux(NEW_ID, ff.val_srst, sig_d, ff.sig_srst);
+                                                       } else {
+                                                               if (ff.pol_srst)
+                                                                       sig_d = module->MuxGate(NEW_ID, sig_d, ff.val_srst[0], ff.sig_srst);
+                                                               else
+                                                                       sig_d = module->MuxGate(NEW_ID, ff.val_srst[0], sig_d, ff.sig_srst);
+                                                       }
+                                               }
 
-                                       if (assign_initval)
-                                               past_q->attributes[ID::init] = initval;
-
-                                       module->remove(cell);
-                                       continue;
-                               }
+                                               if (ff.has_en) {
+                                                       if (!ff.is_fine) {
+                                                               if (ff.pol_en)
+                                                                       sig_d = module->Mux(NEW_ID, ff.sig_q, sig_d, ff.sig_en);
+                                                               else
+                                                                       sig_d = module->Mux(NEW_ID, sig_d, ff.sig_q, ff.sig_en);
+                                                       } else {
+                                                               if (ff.pol_en)
+                                                                       sig_d = module->MuxGate(NEW_ID, ff.sig_q, sig_d, ff.sig_en);
+                                                               else
+                                                                       sig_d = module->MuxGate(NEW_ID, sig_d, ff.sig_q, ff.sig_en);
+                                                       }
+                                               }
 
-                               bool word_dff = cell->type.in(ID($dff), ID($adff), ID($dffsr));
-                               if (word_dff || cell->type.in(ID($_DFF_N_), ID($_DFF_P_),
-                                               ID($_DFF_NN0_), ID($_DFF_NN1_), ID($_DFF_NP0_), ID($_DFF_NP1_),
-                                               ID($_DFF_PP0_), ID($_DFF_PP1_), ID($_DFF_PN0_), ID($_DFF_PN1_),
-                                               ID($_DFFSR_NNN_), ID($_DFFSR_NNP_), ID($_DFFSR_NPN_), ID($_DFFSR_NPP_),
-                                               ID($_DFFSR_PNN_), ID($_DFFSR_PNP_), ID($_DFFSR_PPN_), ID($_DFFSR_PPP_)))
-                               {
-                                       bool clkpol;
-                                       SigSpec clk;
-                                       if (word_dff) {
-                                               clkpol = cell->parameters[ID::CLK_POLARITY].as_bool();
-                                               clk = cell->getPort(ID::CLK);
-                                       }
-                                       else {
-                                               if (cell->type.in(ID($_DFF_P_), ID($_DFF_N_),
-                                                                       ID($_DFF_NN0_), ID($_DFF_NN1_), ID($_DFF_NP0_), ID($_DFF_NP1_),
-                                                                       ID($_DFF_PP0_), ID($_DFF_PP1_), ID($_DFF_PN0_), ID($_DFF_PN1_)))
-                                                       clkpol = cell->type[6] == 'P';
-                                               else if (cell->type.in(ID($_DFFSR_NNN_), ID($_DFFSR_NNP_), ID($_DFFSR_NPN_), ID($_DFFSR_NPP_),
-                                                                       ID($_DFFSR_PNN_), ID($_DFFSR_PNP_), ID($_DFFSR_PPN_), ID($_DFFSR_PPP_)))
-                                                       clkpol = cell->type[8] == 'P';
-                                               else log_abort();
-                                               clk = cell->getPort(ID::C);
-                                       }
+                                               if (ff.has_srst && !(ff.has_en && ff.ce_over_srst)) {
+                                                       if (!ff.is_fine) {
+                                                               if (ff.pol_srst)
+                                                                       sig_d = module->Mux(NEW_ID, sig_d, ff.val_srst, ff.sig_srst);
+                                                               else
+                                                                       sig_d = module->Mux(NEW_ID, ff.val_srst, sig_d, ff.sig_srst);
+                                                       } else {
+                                                               if (ff.pol_srst)
+                                                                       sig_d = module->MuxGate(NEW_ID, sig_d, ff.val_srst[0], ff.sig_srst);
+                                                               else
+                                                                       sig_d = module->MuxGate(NEW_ID, ff.val_srst[0], sig_d, ff.sig_srst);
+                                                       }
+                                               }
 
-                                       Wire *past_clk = module->addWire(NEW_ID);
-                                       past_clk->attributes[ID::init] = clkpol ? State::S1 : State::S0;
+                                               Wire *past_clk = module->addWire(NEW_ID);
+                                               initvals.set_init(past_clk, ff.pol_clk ? State::S1 : State::S0);
 
-                                       if (word_dff)
-                                               module->addFf(NEW_ID, clk, past_clk);
-                                       else
-                                               module->addFfGate(NEW_ID, clk, past_clk);
+                                               if (!ff.is_fine)
+                                                       module->addFf(NEW_ID, ff.sig_clk, past_clk);
+                                               else
+                                                       module->addFfGate(NEW_ID, ff.sig_clk, past_clk);
 
-                                       SigSpec sig_d = cell->getPort(ID::D);
-                                       SigSpec sig_q = cell->getPort(ID::Q);
+                                               log("Replacing %s.%s (%s): CLK=%s, D=%s, Q=%s\n",
+                                                               log_id(module), log_id(cell), log_id(cell->type),
+                                                               log_signal(ff.sig_clk), log_signal(ff.sig_d), log_signal(ff.sig_q));
 
-                                       log("Replacing %s.%s (%s): CLK=%s, D=%s, Q=%s\n",
-                                                       log_id(module), log_id(cell), log_id(cell->type),
-                                                       log_signal(clk), log_signal(sig_d), log_signal(sig_q));
+                                               SigSpec clock_edge_pattern;
 
-                                       SigSpec clock_edge_pattern;
+                                               if (ff.pol_clk) {
+                                                       clock_edge_pattern.append(State::S0);
+                                                       clock_edge_pattern.append(State::S1);
+                                               } else {
+                                                       clock_edge_pattern.append(State::S1);
+                                                       clock_edge_pattern.append(State::S0);
+                                               }
 
-                                       if (clkpol) {
-                                               clock_edge_pattern.append(State::S0);
-                                               clock_edge_pattern.append(State::S1);
-                                       } else {
-                                               clock_edge_pattern.append(State::S1);
-                                               clock_edge_pattern.append(State::S0);
-                                       }
+                                               SigSpec clock_edge = module->Eqx(NEW_ID, {ff.sig_clk, SigSpec(past_clk)}, clock_edge_pattern);
 
-                                       SigSpec clock_edge = module->Eqx(NEW_ID, {clk, SigSpec(past_clk)}, clock_edge_pattern);
+                                               Wire *past_d = module->addWire(NEW_ID, ff.width);
+                                               if (!ff.is_fine)
+                                                       module->addFf(NEW_ID, sig_d, past_d);
+                                               else
+                                                       module->addFfGate(NEW_ID, sig_d, past_d);
 
-                                       Wire *past_d = module->addWire(NEW_ID, GetSize(sig_d));
-                                       Wire *past_q = module->addWire(NEW_ID, GetSize(sig_q));
-                                       if (word_dff) {
-                                               module->addFf(NEW_ID, sig_d, past_d);
-                                               module->addFf(NEW_ID, sig_q, past_q);
-                                       }
-                                       else {
-                                               module->addFfGate(NEW_ID, sig_d, past_d);
-                                               module->addFfGate(NEW_ID, sig_q, past_q);
-                                       }
+                                               if (!ff.val_init.is_fully_undef())
+                                                       initvals.set_init(past_d, ff.val_init);
 
-                                       if (cell->type == ID($adff))
-                                       {
-                                               SigSpec arst = wrap_async_control(module, cell->getPort(ID::ARST), cell->parameters[ID::ARST_POLARITY].as_bool());
-                                               SigSpec qval = module->Mux(NEW_ID, past_q, past_d, clock_edge);
-                                               Const rstval = cell->parameters[ID::ARST_VALUE];
+                                               if (!ff.is_fine)
+                                                       qval = module->Mux(NEW_ID, past_q, past_d, clock_edge);
+                                               else
+                                                       qval = module->MuxGate(NEW_ID, past_q, past_d, clock_edge);
+                                       } else if (ff.has_d) {
 
-                                               module->addMux(NEW_ID, qval, rstval, arst, sig_q);
-                                       }
-                                       else
-                                       if (cell->type.in(ID($_DFF_NN0_), ID($_DFF_NN1_), ID($_DFF_NP0_), ID($_DFF_NP1_),
-                                               ID($_DFF_PP0_), ID($_DFF_PP1_), ID($_DFF_PN0_), ID($_DFF_PN1_)))
-                                       {
-                                               SigSpec arst = wrap_async_control_gate(module, cell->getPort(ID::R), cell->type[7] == 'P');
-                                               SigSpec qval = module->MuxGate(NEW_ID, past_q, past_d, clock_edge);
-                                               SigBit rstval = (cell->type[8] == '1');
+                                               log("Replacing %s.%s (%s): EN=%s, D=%s, Q=%s\n",
+                                                               log_id(module), log_id(cell), log_id(cell->type),
+                                                               log_signal(ff.sig_en), log_signal(ff.sig_d), log_signal(ff.sig_q));
 
-                                               module->addMuxGate(NEW_ID, qval, rstval, arst, sig_q);
-                                       }
-                                       else
-                                       if (cell->type == ID($dffsr))
-                                       {
-                                               SigSpec qval = module->Mux(NEW_ID, past_q, past_d, clock_edge);
-                                               SigSpec setval = wrap_async_control(module, cell->getPort(ID::SET), cell->parameters[ID::SET_POLARITY].as_bool());
-                                               SigSpec clrval = wrap_async_control(module, cell->getPort(ID::CLR), cell->parameters[ID::CLR_POLARITY].as_bool());
+                                               SigSpec sig_en = wrap_async_control(module, ff.sig_en, ff.pol_en);
 
-                                               clrval = module->Not(NEW_ID, clrval);
-                                               qval = module->Or(NEW_ID, qval, setval);
-                                               module->addAnd(NEW_ID, qval, clrval, sig_q);
-                                       }
-                                       else
-                                       if (cell->type.in(ID($_DFFSR_NNN_), ID($_DFFSR_NNP_), ID($_DFFSR_NPN_), ID($_DFFSR_NPP_),
-                                               ID($_DFFSR_PNN_), ID($_DFFSR_PNP_), ID($_DFFSR_PPN_), ID($_DFFSR_PPP_)))
-                                       {
-                                               SigSpec qval = module->MuxGate(NEW_ID, past_q, past_d, clock_edge);
-                                               SigSpec setval = wrap_async_control_gate(module, cell->getPort(ID::S), cell->type[9] == 'P');
-                                               SigSpec clrval = wrap_async_control_gate(module, cell->getPort(ID::R), cell->type[10] == 'P');
+                                               if (!ff.is_fine)
+                                                       qval = module->Mux(NEW_ID, past_q, ff.sig_d, sig_en);
+                                               else
+                                                       qval = module->MuxGate(NEW_ID, past_q, ff.sig_d, sig_en);
+                                       } else {
 
-                                               clrval = module->NotGate(NEW_ID, clrval);
-                                               qval = module->OrGate(NEW_ID, qval, setval);
-                                               module->addAndGate(NEW_ID, qval, clrval, sig_q);
-                                       }
-                                       else if (cell->type == ID($dff))
-                                       {
-                                               module->addMux(NEW_ID, past_q, past_d, clock_edge, sig_q);
-                                       }
-                                       else
-                                       {
-                                               module->addMuxGate(NEW_ID, past_q, past_d, clock_edge, sig_q);
-                                       }
+                                               log("Replacing %s.%s (%s): SET=%s, CLR=%s, Q=%s\n",
+                                                               log_id(module), log_id(cell), log_id(cell->type),
+                                                               log_signal(ff.sig_set), log_signal(ff.sig_clr), log_signal(ff.sig_q));
 
-                                       Const initval;
-                                       bool assign_initval = false;
-                                       for (int i = 0; i < GetSize(sig_d); i++) {
-                                               SigBit qbit = sigmap(sig_q[i]);
-                                               if (initbits.count(qbit)) {
-                                                       initval.bits.push_back(initbits.at(qbit));
-                                                       del_initbits.insert(qbit);
-                                               } else
-                                                       initval.bits.push_back(State::Sx);
-                                               if (initval.bits.back() != State::Sx)
-                                                       assign_initval = true;
+                                               qval = past_q;
                                        }
 
-                                       if (assign_initval) {
-                                               past_d->attributes[ID::init] = initval;
-                                               past_q->attributes[ID::init] = initval;
+                                       if (ff.has_sr) {
+                                               SigSpec setval = wrap_async_control(module, ff.sig_set, ff.pol_set);
+                                               SigSpec clrval = wrap_async_control(module, ff.sig_clr, ff.pol_clr);
+                                               if (!ff.is_fine) {
+                                                       clrval = module->Not(NEW_ID, clrval);
+                                                       qval = module->Or(NEW_ID, qval, setval);
+                                                       module->addAnd(NEW_ID, qval, clrval, ff.sig_q);
+                                               } else {
+                                                       clrval = module->NotGate(NEW_ID, clrval);
+                                                       qval = module->OrGate(NEW_ID, qval, setval);
+                                                       module->addAndGate(NEW_ID, qval, clrval, ff.sig_q);
+                                               }
+                                       } else if (ff.has_arst) {
+                                               SigSpec arst = wrap_async_control(module, ff.sig_arst, ff.pol_arst);
+                                               if (!ff.is_fine)
+                                                       module->addMux(NEW_ID, qval, ff.val_arst, arst, ff.sig_q);
+                                               else
+                                                       module->addMuxGate(NEW_ID, qval, ff.val_arst[0], arst, ff.sig_q);
+                                       } else {
+                                               module->connect(ff.sig_q, qval);
                                        }
 
+                                       initvals.remove_init(ff.sig_q);
                                        module->remove(cell);
                                        continue;
                                }
                        }
-
-                       for (auto wire : module->wires())
-                               if (wire->attributes.count(ID::init) > 0)
-                               {
-                                       bool delete_initattr = true;
-                                       Const initval = wire->attributes.at(ID::init);
-                                       SigSpec initsig = sigmap(wire);
-
-                                       for (int i = 0; i < GetSize(initval) && i < GetSize(initsig); i++)
-                                               if (del_initbits.count(initsig[i]) > 0)
-                                                       initval[i] = State::Sx;
-                                               else if (initval[i] != State::Sx)
-                                                       delete_initattr = false;
-
-                                       if (delete_initattr)
-                                               wire->attributes.erase(ID::init);
-                                       else
-                                               wire->attributes.at(ID::init) = initval;
-                               }
                }
 
        }