Progress in pmgen
authorClifford Wolf <clifford@clifford.at>
Sun, 13 Jan 2019 16:03:58 +0000 (17:03 +0100)
committerClifford Wolf <clifford@clifford.at>
Tue, 15 Jan 2019 10:23:25 +0000 (11:23 +0100)
Signed-off-by: Clifford Wolf <clifford@clifford.at>
passes/pmgen/ice40_dsp.cc
passes/pmgen/ice40_dsp.pmg
passes/pmgen/pmgen.py

index 049ef6c0e8f5a4d7c220e86047d35fd8d2ebfabf..a8f63ebfe353a2c7914e7a2d8a4bd83b197a7b1b 100644 (file)
 USING_YOSYS_NAMESPACE
 PRIVATE_NAMESPACE_BEGIN
 
-void ice40_dsp_accept(ice40_dsp_pm *pm)
-{
-       log("\n");
-       log("mul: %s\n", pm->st.mul ? log_id(pm->st.mul) : "--");
-       log("ffA: %s\n", pm->st.ffA ? log_id(pm->st.ffA) : "--");
-       log("ffB: %s\n", pm->st.ffB ? log_id(pm->st.ffB) : "--");
-       log("ffY: %s\n", pm->st.ffY ? log_id(pm->st.ffY) : "--");
-
-       pm->blacklist(pm->st.mul);
-}
-
 struct Ice40DspPass : public Pass {
        Ice40DspPass() : Pass("ice40_dsp", "iCE40: map multipliers") { }
        void help() YS_OVERRIDE
@@ -64,7 +53,23 @@ struct Ice40DspPass : public Pass {
                for (auto module : design->selected_modules())
                {
                        ice40_dsp_pm pm(module, module->cells());
-                       pm.run(ice40_dsp_accept);
+                       pm.match([&]()
+                       {
+                               log("\n");
+                               log("ffA:   %s\n", log_id(pm.st.ffA, "--"));
+                               log("ffB:   %s\n", log_id(pm.st.ffB, "--"));
+                               log("mul:   %s\n", log_id(pm.st.mul, "--"));
+                               log("ffY:   %s\n", log_id(pm.st.ffY, "--"));
+                               log("addAB: %s\n", log_id(pm.st.addAB, "--"));
+                               log("muxAB: %s\n", log_id(pm.st.muxAB, "--"));
+                               log("ffS:   %s\n", log_id(pm.st.ffS, "--"));
+
+                               pm.blacklist(pm.st.mul);
+                               pm.blacklist(pm.st.ffA);
+                               pm.blacklist(pm.st.ffB);
+                               pm.blacklist(pm.st.ffY);
+                               pm.blacklist(pm.st.ffS);
+                       });
                }
        }
 } Ice40DspPass;
index a2937ddc686125da5b83d850afb26d513a58f778..1370cb66aa57cfb965d9860684cfe72e95197a33 100644 (file)
@@ -1,6 +1,7 @@
 state <SigBit> clock
 state <bool> clock_pol clock_vld
-state <SigSpec> sigA sigB sigY
+state <SigSpec> sigA sigB sigY sigS
+state <Cell*> addAB muxAB
 
 match mul
        select mul->type.in($mul)
@@ -11,14 +12,14 @@ endmatch
 match ffA
        select ffA->type.in($dff)
        // select nusers(port(ffA, \Q)) == 2
-       filter <SigSpec> port(ffA, \Q) === port(mul, \A)
+       index <SigSpec> port(ffA, \Q) === port(mul, \A)
        optional
 endmatch
 
 code sigA clock clock_pol clock_vld
        sigA = port(mul, \A);
 
-       if (ffA != nullptr) {
+       if (ffA) {
                sigA = port(ffA, \D);
 
                clock = port(ffA, \CLK).as_bit();
@@ -30,14 +31,14 @@ endcode
 match ffB
        select ffB->type.in($dff)
        // select nusers(port(ffB, \Q)) == 2
-       filter <SigSpec> port(ffB, \Q) === port(mul, \B)
+       index <SigSpec> port(ffB, \Q) === port(mul, \B)
        optional
 endmatch
 
 code sigB clock clock_pol clock_vld
        sigB = port(mul, \B);
 
-       if (ffB != nullptr) {
+       if (ffB) {
                sigB = port(ffB, \D);
                SigBit c = port(ffB, \CLK).as_bit();
                bool cp = param(ffB, \CLK_POLARITY).as_bool();
@@ -54,14 +55,14 @@ endcode
 match ffY
        select ffY->type.in($dff)
        select nusers(port(ffY, \D)) == 2
-       filter <SigSpec> port(ffY, \D) === port(mul, \Y)
+       index <SigSpec> port(ffY, \D) === port(mul, \Y)
        optional
 endmatch
 
 code sigY clock clock_pol clock_vld
        sigY = port(mul, \Y);
 
-       if (ffY != nullptr) {
+       if (ffY) {
                sigY = port(ffY, \D);
                SigBit c = port(ffY, \CLK).as_bit();
                bool cp = param(ffY, \CLK_POLARITY).as_bool();
@@ -74,3 +75,62 @@ code sigY clock clock_pol clock_vld
                clock_vld = true;
        }
 endcode
+
+match addA
+       select addA->type.in($add, $sub)
+       select nusers(port(addA, \A)) == 2
+       index <SigSpec> port(addA, \A) === sigY
+       optional
+endmatch
+
+match addB
+       if !addA
+       select addB->type.in($add, $sub)
+       select nusers(port(addB, \B)) == 2
+       index <SigSpec> port(addB, \B) === sigY
+       optional
+endmatch
+
+code addAB sigS
+       if (addA) {
+               addAB = addA;
+               sigS = port(addA, \B);
+       }
+       if (addB) {
+               addAB = addB;
+               sigS = port(addB, \A);
+       }
+endcode
+
+match muxA
+       if addAB
+       select muxA->type.in($mux)
+       select nusers(port(muxA, \A)) == 2
+       index <SigSpec> port(muxA, \A) === port(addAB, \Y)
+       optional
+endmatch
+
+match muxB
+       if addAB
+       if !muxA
+       select muxB->type.in($mux)
+       select nusers(port(muxB, \B)) == 2
+       index <SigSpec> port(muxB, \B) === port(addAB, \Y)
+       optional
+endmatch
+
+code muxAB
+       muxAB = addAB;
+       if (muxA)
+               muxAB = muxA;
+       if (muxB)
+               muxAB = muxB;
+endcode
+
+match ffS
+       if muxAB
+       select ffS->type.in($dff)
+       select nusers(port(ffS, \D)) == 2
+       index <SigSpec> port(ffS, \D) === port(muxAB, \Y)
+       index <SigSpec> port(ffS, \Q) === sigS
+endmatch
index 7d33c4fc991ff556d62f2deff4cb185aacf812eb..88d60d29857119d82f4b1df740ad9929bcbab828 100644 (file)
@@ -102,7 +102,9 @@ with open("%s.pmg" % prefix, "r") as f:
             block["cell"] = line[1]
             state_types[line[1]] = "Cell*";
 
+            block["if"] = list()
             block["select"] = list()
+            block["index"] = list()
             block["filter"] = list()
             block["optional"] = False
 
@@ -113,15 +115,25 @@ with open("%s.pmg" % prefix, "r") as f:
                 if len(a) == 0 or a[0].startswith("//"): continue
                 if a[0] == "endmatch": break
 
+                if a[0] == "if":
+                    b = l.lstrip()[2:]
+                    block["if"].append(rewrite_cpp(b.strip()))
+                    continue
+
                 if a[0] == "select":
                     b = l.lstrip()[6:]
                     block["select"].append(rewrite_cpp(b.strip()))
                     continue
 
-                if a[0] == "filter":
-                    m = re.match(r"^\s*filter\s+<(.*?)>\s+(.*?)\s*===\s*(.*?)\s*$", l)
+                if a[0] == "index":
+                    m = re.match(r"^\s*index\s+<(.*?)>\s+(.*?)\s*===\s*(.*?)\s*$", l)
                     assert m
-                    block["filter"].append((m.group(1), rewrite_cpp(m.group(2)), rewrite_cpp(m.group(3))))
+                    block["index"].append((m.group(1), rewrite_cpp(m.group(2)), rewrite_cpp(m.group(3))))
+                    continue
+
+                if a[0] == "filter":
+                    b = l.lstrip()[6:]
+                    block["filter"].append(rewrite_cpp(b.strip()))
                     continue
 
                 if a[0] == "optional":
@@ -167,15 +179,15 @@ with open("%s_pm.h" % prefix, "w") as f:
     print("struct {}_pm {{".format(prefix), file=f)
     print("  Module *module;", file=f)
     print("  SigMap sigmap;", file=f)
-    print("  std::function<void(struct {}_pm*)> on_accept;".format(prefix), file=f)
+    print("  std::function<void()> on_accept;".format(prefix), file=f)
     print("", file=f)
 
     for index in range(len(blocks)):
         block = blocks[index]
         if block["type"] == "match":
             index_types = list()
-            for filt in block["filter"]:
-                index_types.append(filt[0])
+            for entry in block["index"]:
+                index_types.append(entry[0])
             print("  typedef std::tuple<{}> index_{}_key_type;".format(", ".join(index_types), index), file=f)
             print("  dict<index_{}_key_type, vector<Cell*>> index_{};".format(index, index), file=f)
     print("  dict<SigBit, pool<Cell*>> sigusers;", file=f)
@@ -208,7 +220,8 @@ with open("%s_pm.h" % prefix, "w") as f:
     print("", file=f)
 
     print("  void blacklist(Cell *cell) {", file=f)
-    print("    blacklist_cells.insert(cell);", file=f)
+    print("    if (cell != nullptr)", file=f)
+    print("      blacklist_cells.insert(cell);", file=f)
     print("  }", file=f)
     print("", file=f)
 
@@ -248,6 +261,8 @@ with open("%s_pm.h" % prefix, "w") as f:
     print("    for (auto cell : cells) {", file=f)
     print("      for (auto &conn : cell->connections())", file=f)
     print("        add_siguser(conn.second, cell);", file=f)
+    print("    }", file=f)
+    print("    for (auto cell : cells) {", file=f)
 
     for index in range(len(blocks)):
         block = blocks[index]
@@ -257,8 +272,8 @@ with open("%s_pm.h" % prefix, "w") as f:
             for expr in block["select"]:
                 print("        if (!({})) break;".format(expr), file=f)
             print("        index_{}_key_type key;".format(index), file=f)
-            for field, filt in enumerate(block["filter"]):
-                print("        std::get<{}>(key) = {};".format(field, filt[1]), file=f)
+            for field, entry in enumerate(block["index"]):
+                print("        std::get<{}>(key) = {};".format(field, entry[1]), file=f)
             print("        index_{}[key].push_back(cell);".format(index), file=f)
             print("      } while (0);", file=f)
 
@@ -266,15 +281,20 @@ with open("%s_pm.h" % prefix, "w") as f:
     print("  }", file=f)
     print("", file=f)
 
-    print("  void run(std::function<void(struct {}_pm*)> on_accept_f) {{".format(prefix), file=f)
+    print("  void match(std::function<void()> on_accept_f) {{".format(prefix), file=f)
     print("    on_accept = on_accept_f;", file=f)
     print("    rollback = 0;", file=f)
+    for s, t in sorted(state_types.items()):
+        if t.endswith("*"):
+            print("    st.{} = nullptr;".format(s), file=f)
+        else:
+            print("    st.{} = {}();".format(s, t), file=f)
     print("    block_0();", file=f)
     print("  }", file=f)
     print("", file=f)
 
     print("#define reject break", file=f)
-    print("#define accept do { on_accept(this); check_blacklist(); if (rollback) goto rollback_label; } while(0)", file=f)
+    print("#define accept do { on_accept(); check_blacklist(); if (rollback) goto rollback_label; } while(0)", file=f)
     print("", file=f)
 
     for index in range(len(blocks)):
@@ -323,7 +343,7 @@ with open("%s_pm.h" % prefix, "w") as f:
             print("", file=f)
             for s in sorted(restore_st):
                 t = state_types[s]
-                print("    {} backup_{} = st.{};".format(t, s, s), file=f)
+                print("    {} backup_{} = {};".format(t, s, s), file=f)
 
         if block["type"] == "code":
             print("", file=f)
@@ -336,24 +356,42 @@ with open("%s_pm.h" % prefix, "w") as f:
             print("rollback_label: YS_ATTRIBUTE(unused);", file=f)
             print("    } while (0);", file=f)
 
-            if len(restore_st):
+            if len(restore_st) or len(nonconst_st):
                 print("", file=f)
                 for s in sorted(restore_st):
                     t = state_types[s]
-                    print("    st.{} = backup_{};".format(s, s), file=f)
+                    print("    {} = backup_{};".format(s, s), file=f)
+                for s in sorted(nonconst_st):
+                    if s not in restore_st:
+                        t = state_types[s]
+                        if t.endswith("*"):
+                            print("    {} = nullptr;".format(s), file=f)
+                        else:
+                            print("    {} = {}();".format(s, t), file=f)
 
         elif block["type"] == "match":
             assert len(restore_st) == 0
 
+            if len(block["if"]):
+                for expr in block["if"]:
+                    print("", file=f)
+                    print("    if (!({})) {{".format(expr), file=f)
+                    print("      {} = nullptr;".format(block["cell"]), file=f)
+                    print("      block_{}();".format(index+1), file=f)
+                    print("      return;", file=f)
+                    print("    }", file=f)
+
             print("", file=f)
             print("    index_{}_key_type key;".format(index), file=f)
-            for field, filt in enumerate(block["filter"]):
-                print("    std::get<{}>(key) = {};".format(field, filt[2]), file=f)
+            for field, entry in enumerate(block["index"]):
+                print("    std::get<{}>(key) = {};".format(field, entry[2]), file=f)
             print("    const vector<Cell*> &cells = index_{}[key];".format(index), file=f)
 
             print("", file=f)
             print("    for (int idx = 0; idx < GetSize(cells); idx++) {", file=f)
             print("      {} = cells[idx];".format(block["cell"]), file=f)
+            for expr in block["filter"]:
+                print("      if (!({})) continue;".format(expr), file=f)
             print("      block_{}();".format(index+1), file=f)
             print("      if (rollback) {", file=f)
             print("        if (rollback != {}) {{".format(index+1), file=f)
@@ -382,7 +420,7 @@ with open("%s_pm.h" % prefix, "w") as f:
     print("", file=f)
 
     print("  void block_{}() {{".format(len(blocks)), file=f)
-    print("    on_accept(this);", file=f)
+    print("    on_accept();", file=f)
     print("    check_blacklist();", file=f)
     print("  }", file=f)
     print("};", file=f)