Refactor pmgen rollback mechanism
authorClifford Wolf <clifford@clifford.at>
Sat, 17 Aug 2019 11:54:18 +0000 (13:54 +0200)
committerClifford Wolf <clifford@clifford.at>
Sat, 17 Aug 2019 11:54:18 +0000 (13:54 +0200)
Signed-off-by: Clifford Wolf <clifford@clifford.at>
passes/pmgen/pmgen.py

index 8401e1295e63da09cf1a540e144e3f22eaf728a3..18c3bf5a5b679554cb708590250fdba74cec335f 100644 (file)
@@ -362,8 +362,7 @@ with open(outfile, "w") as f:
     print("  dict<SigBit, pool<Cell*>> sigusers;", file=f)
     print("  pool<Cell*> blacklist_cells;", file=f)
     print("  pool<Cell*> autoremove_cells;", file=f)
-    print("  bool blacklist_dirty;", file=f)
-    print("  vector<pair<Cell*,int>> rollback_stack;", file=f)
+    print("  dict<Cell*,int> rollback_cache;", file=f)
     print("  int rollback;", file=f)
     print("", file=f)
 
@@ -399,31 +398,20 @@ with open(outfile, "w") as f:
     print("", file=f)
 
     print("  void blacklist(Cell *cell) {", file=f)
-    print("    if (cell != nullptr) {", file=f)
-    print("      if (blacklist_cells.insert(cell).second)", file=f)
-    print("        blacklist_dirty = true;", file=f)
+    print("    if (cell != nullptr && blacklist_cells.insert(cell).second) {", file=f)
+    print("      auto ptr = rollback_cache.find(cell);", file=f)
+    print("      if (ptr == rollback_cache.end()) return;", file=f)
+    print("      int rb = ptr->second;", file=f)
+    print("      if (rollback == 0 || rollback > rb)", file=f)
+    print("        rollback = rb;", file=f)
     print("    }", file=f)
     print("  }", file=f)
     print("", file=f)
 
-    print("  void check_blacklist() {", file=f)
-    print("    if (!blacklist_dirty)", file=f)
-    print("      return;", file=f)
-    print("    blacklist_dirty = false;", file=f)
-    print("    for (int i = 0; i < GetSize(rollback_stack); i++)", file=f)
-    print("      if (blacklist_cells.count(rollback_stack[i].first)) {", file=f)
-    print("        rollback = rollback_stack[i].second;", file=f)
-    print("        rollback_stack.resize(i);", file=f)
-    print("        return;", file=f)
-    print("      }", file=f)
-    print("  }", file=f)
-    print("", file=f)
-
     print("  void autoremove(Cell *cell) {", file=f)
     print("    if (cell != nullptr) {", file=f)
-    print("      if (blacklist_cells.insert(cell).second)", file=f)
-    print("        blacklist_dirty = true;", file=f)
     print("      autoremove_cells.insert(cell);", file=f)
+    print("      blacklist(cell);", file=f)
     print("    }", file=f)
     print("  }", file=f)
     print("", file=f)
@@ -492,14 +480,13 @@ with open(outfile, "w") as f:
         print("    accept_cnt = 0;", file=f)
         print("    on_accept = on_accept_f;", file=f)
         print("    rollback = 0;", file=f)
-        print("    blacklist_dirty = false;", file=f)
         for s, t in sorted(state_types[current_pattern].items()):
             if t.endswith("*"):
                 print("    st_{}.{} = nullptr;".format(current_pattern, s), file=f)
             else:
                 print("    st_{}.{} = {}();".format(current_pattern, s, t), file=f)
         print("    block_{}(1);".format(patterns[current_pattern]), file=f)
-        print("    log_assert(rollback_stack.empty());", file=f)
+        print("    log_assert(rollback_cache.empty());", file=f)
         print("    return accept_cnt;", file=f)
         print("  }", file=f)
         print("", file=f)
@@ -587,9 +574,9 @@ with open(outfile, "w") as f:
 
         if block["type"] == "code":
             print("", file=f)
-            print("#define reject do { check_blacklist(); goto rollback_label; } while(0)", file=f)
-            print("#define accept do { accept_cnt++; on_accept(); check_blacklist(); if (rollback) goto rollback_label; } while(0)", file=f)
-            print("#define finish do { rollback = -1; rollback_stack.clean(); goto rollback_label; } while(0)", file=f)
+            print("#define reject do { goto rollback_label; } while(0)", file=f)
+            print("#define accept do { accept_cnt++; on_accept(); if (rollback) goto rollback_label; } while(0)", file=f)
+            print("#define finish do { rollback = -1; goto rollback_label; } while(0)", file=f)
             print("#define branch do {{ block_{}(recursion+1); if (rollback) goto rollback_label; }} while(0)".format(index+1), file=f)
             print("#define subpattern(pattern_name) do {{ block_subpattern_{}_ ## pattern_name (recursion+1); if (rollback) goto rollback_label; }} while(0)".format(current_pattern), file=f)
 
@@ -610,10 +597,12 @@ with open(outfile, "w") as f:
             print("    YS_ATTRIBUTE(unused);", file=f)
 
             if len(block["fcode"]):
-                print("#define accept do { accept_cnt++; on_accept(); check_blacklist(); } while(0)", file=f)
-                print("#define finish do { rollback = -1; rollback_stack.clean(); return; } while(0)", file=f)
+                print("#define accept do { accept_cnt++; on_accept(); } while(0)", file=f)
+                print("#define finish do { rollback = -1; goto finish_label; } while(0)", file=f)
                 for line in block["fcode"]:
                     print("  " + line, file=f)
+                print("finish_label:", file=f)
+                print("    YS_ATTRIBUTE(unused);", file=f)
                 print("#undef accept", file=f)
                 print("#undef finish", file=f)
 
@@ -664,11 +653,11 @@ with open(outfile, "w") as f:
                 print("        if (!({})) continue;".format(expr), file=f)
             if block["semioptional"] or block["genargs"] is not None:
                 print("        found_any_match = true;", file=f)
-            print("        rollback_stack.push_back(make_pair(cells[idx], recursion));", file=f)
+            print("        auto rollback_ptr = rollback_cache.insert(make_pair(cells[idx], recursion));", file=f)
             print("        block_{}(recursion+1);".format(index+1), file=f)
-            print("        if (rollback == 0) {", file=f)
-            print("          rollback_stack.pop_back();", file=f)
-            print("        } else {", file=f)
+            print("        if (rollback_ptr.second)", file=f)
+            print("          rollback_cache.erase(rollback_ptr.first);", file=f)
+            print("        if (rollback) {", file=f)
             print("          if (rollback != recursion) {{".format(index+1), file=f)
             print("            {} = backup_{};".format(block["cell"], block["cell"]), file=f)
             print("            return;", file=f)
@@ -690,7 +679,7 @@ with open(outfile, "w") as f:
             print("    {} = backup_{};".format(block["cell"], block["cell"]), file=f)
 
             if block["genargs"] is not None:
-                print("#define finish do { rollback = -1; rollback_stack.clean(); return; } while(0)", file=f)
+                print("#define finish do { rollback = -1; return; } while(0)", file=f)
                 print("    if (generate_mode && !found_any_match) {", file=f)
                 if len(block["genargs"]) == 1:
                     print("    if (rng(100) >= {}) return;".format(block["genargs"][0]), file=f)