Add pmgen support for multiple patterns in one matcher
authorClifford Wolf <clifford@clifford.at>
Mon, 29 Apr 2019 11:02:05 +0000 (13:02 +0200)
committerClifford Wolf <clifford@clifford.at>
Mon, 29 Apr 2019 11:02:05 +0000 (13:02 +0200)
Signed-off-by: Clifford Wolf <clifford@clifford.at>
passes/pmgen/ice40_dsp.cc
passes/pmgen/ice40_dsp.pmg
passes/pmgen/pmgen.py

index 3a054a4632bc6966d6c40bde5328ba6797980c13..36ba1dabeaf6fb99518e4e018c708e69698f8d1c 100644 (file)
@@ -26,40 +26,42 @@ PRIVATE_NAMESPACE_BEGIN
 
 void create_ice40_dsp(ice40_dsp_pm &pm)
 {
+       auto &st = pm.st_ice40_dsp;
+
 #if 0
        log("\n");
-       log("ffA:   %s\n", log_id(pm.st.ffA, "--"));
-       log("ffB:   %s\n", log_id(pm.st.ffB, "--"));
-       log("mul:   %s\n", log_id(pm.st.mul, "--"));
-       log("ffY:   %s\n", log_id(pm.st.ffY, "--"));
-       log("addAB: %s\n", log_id(pm.st.addAB, "--"));
-       log("muxAB: %s\n", log_id(pm.st.muxAB, "--"));
-       log("ffS:   %s\n", log_id(pm.st.ffS, "--"));
+       log("ffA:   %s\n", log_id(st.ffA, "--"));
+       log("ffB:   %s\n", log_id(st.ffB, "--"));
+       log("mul:   %s\n", log_id(st.mul, "--"));
+       log("ffY:   %s\n", log_id(st.ffY, "--"));
+       log("addAB: %s\n", log_id(st.addAB, "--"));
+       log("muxAB: %s\n", log_id(st.muxAB, "--"));
+       log("ffS:   %s\n", log_id(st.ffS, "--"));
 #endif
 
-       log("Checking %s.%s for iCE40 DSP inference.\n", log_id(pm.module), log_id(pm.st.mul));
+       log("Checking %s.%s for iCE40 DSP inference.\n", log_id(pm.module), log_id(st.mul));
 
-       if (GetSize(pm.st.sigA) > 16) {
-               log("  input A (%s) is too large (%d > 16).\n", log_signal(pm.st.sigA), GetSize(pm.st.sigA));
+       if (GetSize(st.sigA) > 16) {
+               log("  input A (%s) is too large (%d > 16).\n", log_signal(st.sigA), GetSize(st.sigA));
                return;
        }
 
-       if (GetSize(pm.st.sigB) > 16) {
-               log("  input B (%s) is too large (%d > 16).\n", log_signal(pm.st.sigB), GetSize(pm.st.sigB));
+       if (GetSize(st.sigB) > 16) {
+               log("  input B (%s) is too large (%d > 16).\n", log_signal(st.sigB), GetSize(st.sigB));
                return;
        }
 
-       if (GetSize(pm.st.sigS) > 32) {
-               log("  accumulator (%s) is too large (%d > 32).\n", log_signal(pm.st.sigS), GetSize(pm.st.sigS));
+       if (GetSize(st.sigS) > 32) {
+               log("  accumulator (%s) is too large (%d > 32).\n", log_signal(st.sigS), GetSize(st.sigS));
                return;
        }
 
-       if (GetSize(pm.st.sigY) > 32) {
-               log("  output (%s) is too large (%d > 32).\n", log_signal(pm.st.sigY), GetSize(pm.st.sigY));
+       if (GetSize(st.sigY) > 32) {
+               log("  output (%s) is too large (%d > 32).\n", log_signal(st.sigY), GetSize(st.sigY));
                return;
        }
 
-       bool mul_signed = pm.st.mul->getParam("\\A_SIGNED").as_bool();
+       bool mul_signed = st.mul->getParam("\\A_SIGNED").as_bool();
 
        if (mul_signed) {
                log("  inference of signed iCE40 DSP arithmetic is currently not supported.\n");
@@ -69,21 +71,21 @@ void create_ice40_dsp(ice40_dsp_pm &pm)
        log("  replacing $mul with SB_MAC16 cell.\n");
 
        Cell *cell = pm.module->addCell(NEW_ID, "\\SB_MAC16");
-       pm.module->swap_names(cell, pm.st.mul);
+       pm.module->swap_names(cell, st.mul);
 
        // SB_MAC16 Input Interface
 
-       SigSpec A = pm.st.sigA;
+       SigSpec A = st.sigA;
        A.extend_u0(16, mul_signed);
 
-       SigSpec B = pm.st.sigB;
+       SigSpec B = st.sigB;
        B.extend_u0(16, mul_signed);
 
        SigSpec CD;
-       if (pm.st.muxA)
-               CD = pm.st.muxA->getPort("\\B");
-       if (pm.st.muxB)
-               CD = pm.st.muxB->getPort("\\A");
+       if (st.muxA)
+               CD = st.muxA->getPort("\\B");
+       if (st.muxB)
+               CD = st.muxB->getPort("\\A");
        CD.extend_u0(32, mul_signed);
 
        cell->setPort("\\A", A);
@@ -91,8 +93,8 @@ void create_ice40_dsp(ice40_dsp_pm &pm)
        cell->setPort("\\C", CD.extract(0, 16));
        cell->setPort("\\D", CD.extract(16, 16));
 
-       cell->setParam("\\A_REG", pm.st.ffA ? State::S1 : State::S0);
-       cell->setParam("\\B_REG", pm.st.ffB ? State::S1 : State::S0);
+       cell->setParam("\\A_REG", st.ffA ? State::S1 : State::S0);
+       cell->setParam("\\B_REG", st.ffB ? State::S1 : State::S0);
 
        cell->setPort("\\AHOLD", State::S0);
        cell->setPort("\\BHOLD", State::S0);
@@ -102,25 +104,25 @@ void create_ice40_dsp(ice40_dsp_pm &pm)
        cell->setPort("\\IRSTTOP", State::S0);
        cell->setPort("\\IRSTBOT", State::S0);
 
-       if (pm.st.clock_vld)
+       if (st.clock_vld)
        {
-               cell->setPort("\\CLK", pm.st.clock);
+               cell->setPort("\\CLK", st.clock);
                cell->setPort("\\CE", State::S1);
-               cell->setParam("\\NEG_TRIGGER", pm.st.clock_pol ? State::S0 : State::S1);
+               cell->setParam("\\NEG_TRIGGER", st.clock_pol ? State::S0 : State::S1);
 
-               log("  clock: %s (%s)", log_signal(pm.st.clock), pm.st.clock_pol ? "posedge" : "negedge");
+               log("  clock: %s (%s)", log_signal(st.clock), st.clock_pol ? "posedge" : "negedge");
 
-               if (pm.st.ffA)
-                       log(" ffA:%s", log_id(pm.st.ffA));
+               if (st.ffA)
+                       log(" ffA:%s", log_id(st.ffA));
 
-               if (pm.st.ffB)
-                       log(" ffB:%s", log_id(pm.st.ffB));
+               if (st.ffB)
+                       log(" ffB:%s", log_id(st.ffB));
 
-               if (pm.st.ffY)
-                       log(" ffY:%s", log_id(pm.st.ffY));
+               if (st.ffY)
+                       log(" ffY:%s", log_id(st.ffY));
 
-               if (pm.st.ffS)
-                       log(" ffS:%s", log_id(pm.st.ffS));
+               if (st.ffS)
+                       log(" ffS:%s", log_id(st.ffS));
 
                log("\n");
        }
@@ -144,16 +146,16 @@ void create_ice40_dsp(ice40_dsp_pm &pm)
 
        // SB_MAC16 Output Interface
 
-       SigSpec O = pm.st.ffS ? pm.st.sigS : pm.st.sigY;
+       SigSpec O = st.ffS ? st.sigS : st.sigY;
        if (GetSize(O) < 32)
                O.append(pm.module->addWire(NEW_ID, 32-GetSize(O)));
 
        cell->setPort("\\O", O);
 
-       if (pm.st.addAB) {
-               log("  accumulator %s (%s)\n", log_id(pm.st.addAB), log_id(pm.st.addAB->type));
-               cell->setPort("\\ADDSUBTOP", pm.st.addAB->type == "$add" ? State::S0 : State::S1);
-               cell->setPort("\\ADDSUBBOT", pm.st.addAB->type == "$add" ? State::S0 : State::S1);
+       if (st.addAB) {
+               log("  accumulator %s (%s)\n", log_id(st.addAB), log_id(st.addAB->type));
+               cell->setPort("\\ADDSUBTOP", st.addAB->type == "$add" ? State::S0 : State::S1);
+               cell->setPort("\\ADDSUBBOT", st.addAB->type == "$add" ? State::S0 : State::S1);
        } else {
                cell->setPort("\\ADDSUBTOP", State::S0);
                cell->setPort("\\ADDSUBBOT", State::S0);
@@ -166,10 +168,10 @@ void create_ice40_dsp(ice40_dsp_pm &pm)
        cell->setPort("\\OHOLDBOT", State::S0);
 
        SigSpec acc_reset = State::S0;
-       if (pm.st.muxA)
-               acc_reset = pm.st.muxA->getPort("\\S");
-       if (pm.st.muxB)
-               acc_reset = pm.module->Not(NEW_ID, pm.st.muxB->getPort("\\S"));
+       if (st.muxA)
+               acc_reset = st.muxA->getPort("\\S");
+       if (st.muxB)
+               acc_reset = pm.module->Not(NEW_ID, st.muxB->getPort("\\S"));
 
        cell->setPort("\\OLOADTOP", acc_reset);
        cell->setPort("\\OLOADBOT", acc_reset);
@@ -179,17 +181,17 @@ void create_ice40_dsp(ice40_dsp_pm &pm)
        cell->setParam("\\C_REG", State::S0);
        cell->setParam("\\D_REG", State::S0);
 
-       cell->setParam("\\TOP_8x8_MULT_REG", pm.st.ffY ? State::S1 : State::S0);
-       cell->setParam("\\BOT_8x8_MULT_REG", pm.st.ffY ? State::S1 : State::S0);
-       cell->setParam("\\PIPELINE_16x16_MULT_REG1", pm.st.ffY ? State::S1 : State::S0);
+       cell->setParam("\\TOP_8x8_MULT_REG", st.ffY ? State::S1 : State::S0);
+       cell->setParam("\\BOT_8x8_MULT_REG", st.ffY ? State::S1 : State::S0);
+       cell->setParam("\\PIPELINE_16x16_MULT_REG1", st.ffY ? State::S1 : State::S0);
        cell->setParam("\\PIPELINE_16x16_MULT_REG2", State::S0);
 
-       cell->setParam("\\TOPOUTPUT_SELECT", Const(pm.st.ffS ? 1 : 3, 2));
+       cell->setParam("\\TOPOUTPUT_SELECT", Const(st.ffS ? 1 : 3, 2));
        cell->setParam("\\TOPADDSUB_LOWERINPUT", Const(2, 2));
        cell->setParam("\\TOPADDSUB_UPPERINPUT", State::S0);
        cell->setParam("\\TOPADDSUB_CARRYSELECT", Const(3, 2));
 
-       cell->setParam("\\BOTOUTPUT_SELECT", Const(pm.st.ffS ? 1 : 3, 2));
+       cell->setParam("\\BOTOUTPUT_SELECT", Const(st.ffS ? 1 : 3, 2));
        cell->setParam("\\BOTADDSUB_LOWERINPUT", Const(2, 2));
        cell->setParam("\\BOTADDSUB_UPPERINPUT", State::S0);
        cell->setParam("\\BOTADDSUB_CARRYSELECT", Const(0, 2));
@@ -198,9 +200,9 @@ void create_ice40_dsp(ice40_dsp_pm &pm)
        cell->setParam("\\A_SIGNED", mul_signed ? State::S1 : State::S0);
        cell->setParam("\\B_SIGNED", mul_signed ? State::S1 : State::S0);
 
-       pm.autoremove(pm.st.mul);
-       pm.autoremove(pm.st.ffY);
-       pm.autoremove(pm.st.ffS);
+       pm.autoremove(st.mul);
+       pm.autoremove(st.ffY);
+       pm.autoremove(st.ffS);
 }
 
 struct Ice40DspPass : public Pass {
@@ -230,7 +232,7 @@ struct Ice40DspPass : public Pass {
                extra_args(args, argidx, design);
 
                for (auto module : design->selected_modules())
-                       ice40_dsp_pm(module, module->selected_cells()).run(create_ice40_dsp);
+                       ice40_dsp_pm(module, module->selected_cells()).run_ice40_dsp(create_ice40_dsp);
        }
 } Ice40DspPass;
 
index 96c62e313b3991d90d28ba111cb16c1e534814dc..1f3590d4ecd84610d1bf99bac8fca35267dbe773 100644 (file)
@@ -1,3 +1,5 @@
+pattern ice40_dsp
+
 state <SigBit> clock
 state <bool> clock_pol clock_vld
 state <SigSpec> sigA sigB sigY sigS
index edc1ad7feda0702bc1f8da98df51bea04290a6f1..bb4c9d66bc32e7a66705fc48b65c915315e6dcc1 100644 (file)
@@ -10,14 +10,17 @@ pp = pprint.PrettyPrinter(indent=4)
 prefix = None
 pmgfiles = list()
 outfile = None
+debug = False
 
-opts, args = getopt.getopt(sys.argv[1:], "p:o:")
+opts, args = getopt.getopt(sys.argv[1:], "p:o:d")
 
 for o, a in opts:
     if o == "-p":
-        prefix = o
+        prefix = a
     elif o == "-o":
         outfile = a
+    elif o == "-d":
+        debug = True
 
 if outfile is None:
     outfile = args[-1]
@@ -32,6 +35,8 @@ for a in args:
 
 assert prefix is not None
 
+current_pattern = None
+patterns = dict()
 state_types = dict()
 udata_types = dict()
 blocks = list()
@@ -98,6 +103,7 @@ def rewrite_cpp(s):
     return "".join(t)
 
 def process_pmgfile(f):
+    global current_pattern
     while True:
         line = f.readline()
         if line == "": break
@@ -107,14 +113,31 @@ def process_pmgfile(f):
         if len(cmd) == 0 or cmd[0].startswith("//"): continue
         cmd = cmd[0]
 
+        if cmd == "pattern":
+            if current_pattern is not None:
+                block = dict()
+                block["type"] = "final"
+                block["pattern"] = current_pattern
+                blocks.append(block)
+            line = line.split()
+            assert len(line) == 2
+            assert line[1] not in patterns
+            current_pattern = line[1]
+            patterns[current_pattern] = len(blocks)
+            state_types[current_pattern] = dict()
+            udata_types[current_pattern] = dict()
+            continue
+
+        assert current_pattern is not None
+
         if cmd == "state":
             m = re.match(r"^state\s+<(.*?)>\s+(([A-Za-z_][A-Za-z_0-9]*\s+)*[A-Za-z_][A-Za-z_0-9]*)\s*$", line)
             assert m
             type_str = m.group(1)
             states_str = m.group(2)
             for s in re.split(r"\s+", states_str):
-                assert s not in state_types
-                state_types[s] = type_str
+                assert s not in state_types[current_pattern]
+                state_types[current_pattern][s] = type_str
             continue
 
         if cmd == "udata":
@@ -123,19 +146,20 @@ def process_pmgfile(f):
             type_str = m.group(1)
             udatas_str = m.group(2)
             for s in re.split(r"\s+", udatas_str):
-                assert s not in udata_types
-                udata_types[s] = type_str
+                assert s not in udata_types[current_pattern]
+                udata_types[current_pattern][s] = type_str
             continue
 
         if cmd == "match":
             block = dict()
             block["type"] = "match"
+            block["pattern"] = current_pattern
 
             line = line.split()
             assert len(line) == 2
-            assert line[1] not in state_types
+            assert line[1] not in state_types[current_pattern]
             block["cell"] = line[1]
-            state_types[line[1]] = "Cell*";
+            state_types[current_pattern][line[1]] = "Cell*";
 
             block["if"] = list()
             block["select"] = list()
@@ -178,15 +202,18 @@ def process_pmgfile(f):
                 assert False
 
             blocks.append(block)
+            continue
 
         if cmd == "code":
             block = dict()
             block["type"] = "code"
+            block["pattern"] = current_pattern
+
             block["code"] = list()
             block["states"] = set()
 
             for s in line.split()[1:]:
-                assert s in state_types
+                assert s in state_types[current_pattern]
                 block["states"].add(s)
 
             while True:
@@ -199,11 +226,25 @@ def process_pmgfile(f):
                 block["code"].append(rewrite_cpp(l.rstrip()))
 
             blocks.append(block)
+            continue
+
+        assert False
 
 for fn in pmgfiles:
     with open(fn, "r") as f:
         process_pmgfile(f)
 
+if current_pattern is not None:
+    block = dict()
+    block["type"] = "final"
+    block["pattern"] = current_pattern
+    blocks.append(block)
+
+current_pattern = None
+
+if debug:
+    pp.pprint(blocks)
+
 with open(outfile, "w") as f:
     print("// Generated by pmgen.py from {}.pgm".format(prefix), file=f)
     print("", file=f)
@@ -236,17 +277,19 @@ with open(outfile, "w") as f:
     print("  int rollback;", file=f)
     print("", file=f)
 
-    print("  struct state_t {", file=f)
-    for s, t in sorted(state_types.items()):
-        print("    {} {};".format(t, s), file=f)
-    print("  } st;", file=f)
-    print("", file=f)
+    for current_pattern in sorted(patterns.keys()):
+        print("  struct state_{}_t {{".format(current_pattern), file=f)
+        for s, t in sorted(state_types[current_pattern].items()):
+            print("    {} {};".format(t, s), file=f)
+        print("  }} st_{};".format(current_pattern), file=f)
+        print("", file=f)
 
-    print("  struct udata_t {", file=f)
-    for s, t in sorted(udata_types.items()):
-        print("    {} {};".format(t, s), file=f)
-    print("  } ud;", file=f)
-    print("", file=f)
+        print("  struct udata_{}_t {{".format(current_pattern), file=f)
+        for s, t in sorted(udata_types[current_pattern].items()):
+            print("    {} {};".format(t, s), file=f)
+        print("  }} ud_{};".format(current_pattern), file=f)
+        print("", file=f)
+    current_pattern = None
 
     for v, n in sorted(ids.items()):
         if n[0] == "\\":
@@ -282,20 +325,22 @@ with open(outfile, "w") as f:
     print("  }", file=f)
     print("", file=f)
 
-    print("  void check_blacklist() {", file=f)
-    print("    if (!blacklist_dirty)", file=f)
-    print("      return;", file=f)
-    print("    blacklist_dirty = false;", file=f)
-    for index in range(len(blocks)):
-        block = blocks[index]
-        if block["type"] == "match":
-            print("    if (st.{} != nullptr && blacklist_cells.count(st.{})) {{".format(block["cell"], block["cell"]), file=f)
-            print("      rollback = {};".format(index+1), file=f)
-            print("      return;", file=f)
-            print("    }", file=f)
-    print("    rollback = 0;", file=f)
-    print("  }", file=f)
-    print("", file=f)
+    for current_pattern in sorted(patterns.keys()):
+        print("  void check_blacklist_{}() {{".format(current_pattern), file=f)
+        print("    if (!blacklist_dirty)", file=f)
+        print("      return;", file=f)
+        print("    blacklist_dirty = false;", file=f)
+        for index in range(len(blocks)):
+            block = blocks[index]
+            if block["type"] == "match":
+                print("    if (st_{}.{} != nullptr && blacklist_cells.count(st_{}.{})) {{".format(current_pattern, block["cell"], current_pattern, block["cell"]), file=f)
+                print("      rollback = {};".format(index+1), file=f)
+                print("      return;", file=f)
+                print("    }", file=f)
+        print("    rollback = 0;", file=f)
+        print("  }", file=f)
+        print("", file=f)
+    current_pattern = None
 
     print("  SigSpec port(Cell *cell, IdString portname) {", file=f)
     print("    return sigmap(cell->getPort(portname));", file=f)
@@ -318,11 +363,13 @@ with open(outfile, "w") as f:
 
     print("  {}_pm(Module *module, const vector<Cell*> &cells) :".format(prefix), file=f)
     print("      module(module), sigmap(module) {", file=f)
-    for s, t in sorted(udata_types.items()):
-        if t.endswith("*"):
-            print("    ud.{} = nullptr;".format(s), file=f)
-        else:
-            print("    ud.{} = {}();".format(s, t), file=f)
+    for current_pattern in sorted(patterns.keys()):
+        for s, t in sorted(udata_types[current_pattern].items()):
+            if t.endswith("*"):
+                print("    ud_{}.{} = nullptr;".format(current_pattern,s), file=f)
+            else:
+                print("    ud_{}.{} = {}();".format(current_pattern, s, t), file=f)
+    current_pattern = None
     print("    for (auto cell : module->cells()) {", file=f)
     print("      for (auto &conn : cell->connections())", file=f)
     print("        add_siguser(conn.second, cell);", file=f)
@@ -352,34 +399,48 @@ with open(outfile, "w") as f:
     print("  }", file=f)
     print("", file=f)
 
-    print("  void run(std::function<void()> on_accept_f) {", file=f)
-    print("    on_accept = on_accept_f;", file=f)
-    print("    rollback = 0;", file=f)
-    print("    blacklist_dirty = false;", file=f)
-    for s, t in sorted(state_types.items()):
-        if t.endswith("*"):
-            print("    st.{} = nullptr;".format(s), file=f)
-        else:
-            print("    st.{} = {}();".format(s, t), file=f)
-    print("    block_0();", file=f)
-    print("  }", file=f)
-    print("", file=f)
-
-    print("  void run(std::function<void({}_pm&)> on_accept_f) {{".format(prefix), file=f)
-    print("    run([&](){on_accept_f(*this);});", file=f)
-    print("  }", file=f)
-    print("", file=f)
+    for current_pattern in sorted(patterns.keys()):
+        print("  void run_{}(std::function<void()> on_accept_f) {{".format(current_pattern), file=f)
+        print("    on_accept = on_accept_f;", file=f)
+        print("    rollback = 0;", file=f)
+        print("    blacklist_dirty = false;", file=f)
+        for s, t in sorted(state_types[current_pattern].items()):
+            if t.endswith("*"):
+                print("    st_{}.{} = nullptr;".format(current_pattern, s), file=f)
+            else:
+                print("    st_{}.{} = {}();".format(current_pattern, s, t), file=f)
+        print("    block_{}();".format(patterns[current_pattern]), file=f)
+        print("  }", file=f)
+        print("", file=f)
+        print("  void run_{}(std::function<void({}_pm&)> on_accept_f) {{".format(current_pattern, prefix), file=f)
+        print("    run_{}([&](){{on_accept_f(*this);}});".format(current_pattern), file=f)
+        print("  }", file=f)
+        print("", file=f)
+        print("  void run_{}(std::function<void(state_{}_t&)> on_accept_f) {{".format(current_pattern, current_pattern), file=f)
+        print("    run_{}([&](){{on_accept_f(st_{});}});".format(current_pattern, current_pattern), file=f)
+        print("  }", file=f)
+        print("", file=f)
+    current_pattern = None
 
     for index in range(len(blocks)):
         block = blocks[index]
 
         print("  void block_{}() {{".format(index), file=f)
+        current_pattern = block["pattern"]
+
+        if block["type"] == "final":
+            print("    on_accept();", file=f)
+            print("    check_blacklist_{}();".format(current_pattern), file=f)
+            print("  }", file=f)
+            if index+1 != len(blocks):
+                print("", file=f)
+            continue
 
         const_st = set()
         nonconst_st = set()
         restore_st = set()
 
-        for i in range(index):
+        for i in range(patterns[current_pattern], index):
             if blocks[i]["type"] == "code":
                 for s in blocks[i]["states"]:
                     const_st.add(s)
@@ -402,27 +463,27 @@ with open(outfile, "w") as f:
             assert False
 
         for s in sorted(const_st):
-            t = state_types[s]
+            t = state_types[current_pattern][s]
             if t.endswith("*"):
-                print("    {} const &{} YS_ATTRIBUTE(unused) = st.{};".format(t, s, s), file=f)
+                print("    {} const &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)
             else:
-                print("    const {} &{} YS_ATTRIBUTE(unused) = st.{};".format(t, s, s), file=f)
+                print("    const {} &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)
 
         for s in sorted(nonconst_st):
-            t = state_types[s]
-            print("    {} &{} YS_ATTRIBUTE(unused) = st.{};".format(t, s, s), file=f)
+            t = state_types[current_pattern][s]
+            print("    {} &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)
 
         if len(restore_st):
             print("", file=f)
             for s in sorted(restore_st):
-                t = state_types[s]
+                t = state_types[current_pattern][s]
                 print("    {} backup_{} = {};".format(t, s, s), file=f)
 
         if block["type"] == "code":
             print("", file=f)
             print("    do {", file=f)
-            print("#define reject do { check_blacklist(); goto rollback_label; } while(0)", file=f)
-            print("#define accept do { on_accept(); check_blacklist(); if (rollback) goto rollback_label; } while(0)", file=f)
+            print("#define reject do {{ check_blacklist_{}(); goto rollback_label; }} while(0)".format(current_pattern), file=f)
+            print("#define accept do {{ on_accept(); check_blacklist_{}(); if (rollback) goto rollback_label; }} while(0)".format(current_pattern), file=f)
             print("#define branch do {{ block_{}(); if (rollback) goto rollback_label; }} while(0)".format(index+1), file=f)
 
             for line in block["code"]:
@@ -441,11 +502,11 @@ with open(outfile, "w") as f:
             if len(restore_st) or len(nonconst_st):
                 print("", file=f)
                 for s in sorted(restore_st):
-                    t = state_types[s]
+                    t = state_types[current_pattern][s]
                     print("    {} = backup_{};".format(s, s), file=f)
                 for s in sorted(nonconst_st):
                     if s not in restore_st:
-                        t = state_types[s]
+                        t = state_types[current_pattern][s]
                         if t.endswith("*"):
                             print("    {} = nullptr;".format(s), file=f)
                         else:
@@ -494,17 +555,10 @@ with open(outfile, "w") as f:
         else:
             assert False
 
-
+        current_pattern = None
         print("  }", file=f)
         print("", file=f)
 
-    print("  void block_{}() {{".format(len(blocks)), file=f)
-    print("    on_accept();", file=f)
-    print("    check_blacklist();", file=f)
-    print("  }", file=f)
     print("};", file=f)
-
     print("", file=f)
     print("YOSYS_NAMESPACE_END", file=f)
-
-# pp.pprint(blocks)