Improvements in pmgen for recursive patterns
authorClifford Wolf <clifford@clifford.at>
Thu, 15 Aug 2019 16:34:36 +0000 (18:34 +0200)
committerClifford Wolf <clifford@clifford.at>
Thu, 15 Aug 2019 16:35:56 +0000 (18:35 +0200)
Signed-off-by: Clifford Wolf <clifford@clifford.at>
Makefile
passes/pmgen/Makefile.inc
passes/pmgen/README.md
passes/pmgen/peepopt_shiftmul.pmg
passes/pmgen/pmgen.py

index 95b5d451beb26b5d83899f1ca6ff6830caba165e..db8915225abc8a7e5893cd66b7e792bfba693649 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -487,6 +487,11 @@ define add_include_file
 $(eval $(call add_share_file,$(dir share/include/$(1)),$(1)))
 endef
 
+define add_extra_objs
+EXTRA_OBJS += $(1)
+.SECONDARY: $(1)
+endef
+
 ifeq ($(PRETTY), 1)
 P_STATUS = 0
 P_OFFSET = 0
index 7911132db44d6c5f94d7b105a9b1740b2623c61f..0d977635bc88a6de16beebbe1209eb93b4c59762 100644 (file)
@@ -1,20 +1,17 @@
-OBJS += passes/pmgen/ice40_dsp.o
-OBJS += passes/pmgen/peepopt.o
+%_pm.h: passes/pmgen/pmgen.py %.pmg
+       $(P) mkdir -p passes/pmgen && python3 $< -o $@ -p $(subst _pm.h,,$(notdir $@)) $(filter-out $<,$^)
 
 # --------------------------------------
 
+OBJS += passes/pmgen/ice40_dsp.o
 passes/pmgen/ice40_dsp.o: passes/pmgen/ice40_dsp_pm.h
-EXTRA_OBJS += passes/pmgen/ice40_dsp_pm.h
-.SECONDARY: passes/pmgen/ice40_dsp_pm.h
-
-passes/pmgen/ice40_dsp_pm.h: passes/pmgen/pmgen.py passes/pmgen/ice40_dsp.pmg
-       $(P) mkdir -p passes/pmgen && python3 $< -o $@ -p ice40_dsp $(filter-out $<,$^)
+$(eval $(call add_extra_objs,passes/pmgen/ice40_dsp_pm.h))
 
 # --------------------------------------
 
+OBJS += passes/pmgen/peepopt.o
 passes/pmgen/peepopt.o: passes/pmgen/peepopt_pm.h
-EXTRA_OBJS += passes/pmgen/peepopt_pm.h
-.SECONDARY: passes/pmgen/peepopt_pm.h
+$(eval $(call add_extra_objs,passes/pmgen/peepopt_pm.h))
 
 PEEPOPT_PATTERN  = passes/pmgen/peepopt_shiftmul.pmg
 PEEPOPT_PATTERN += passes/pmgen/peepopt_muldiv.pmg
index 2f0b1fd5a489bcd6149b88337fd2dcf5e88c1ca6..db722c818a2a43fe44ced75927e9aafbd5a46cc0 100644 (file)
@@ -232,5 +232,28 @@ But in some cases it is more natural to utilize the implicit branch statement:
         portAB = \B;
     endcode
 
-There is an implicit `code..endcode` block at the end of each `.pmg` file
+There is an implicit `code..endcode` block at the end of each (sub)pattern
 that just accepts everything that gets all the way there.
+
+A `code..finally..endcode` block executes the code after `finally` during
+back-tracking. This is useful for maintaining user data state or printing
+debug messages. For example:
+
+    udata <vector<Cell*>> stack
+
+    code
+        stack.push_back(addAB);
+    finally
+        stack.pop_back();
+    endcode
+
+Declaring a subpattern
+----------------------
+
+A subpattern starts with a line containing the `subpattern` keyword followed
+by the name of the subpattern. Subpatterns can be called from a `code` block
+using a `subpattern(<subpattern_name>);` C statement.
+
+Arguments may be passed to subpattern via state variables. The `subpattern`
+line must be followed by a `arg <arg1> <arg2> ...` line that lists the
+state variables used to pass arguments. Subpatterns allow recursion.
index 6adab4e5f1dc43ae3f8cb81e82be17988371da8a..d766d9e4ade022ee91d3f9612fb7a6d91d27258d 100644 (file)
@@ -34,6 +34,7 @@ match mul
 endmatch
 
 code
+{
        IdString const_factor_port = port(mul, \A).is_fully_const() ? \A : \B;
        IdString const_factor_signed = const_factor_port == \A ? \A_SIGNED : \B_SIGNED;
        Const const_factor_cnst = port(mul, const_factor_port).as_const();
@@ -91,4 +92,5 @@ code
 
        blacklist(shift);
        reject;
+}
 endcode
index 81052afce0509b62c38f22469f8482c8acd17fe1..22a7a5225367943f641b80e465e614a5811a81f1 100644 (file)
@@ -38,7 +38,10 @@ for a in args:
 assert prefix is not None
 
 current_pattern = None
+current_subpattern = None
 patterns = dict()
+subpatterns = dict()
+subpattern_args = dict()
 state_types = dict()
 udata_types = dict()
 blocks = list()
@@ -104,9 +107,12 @@ def rewrite_cpp(s):
 
     return "".join(t)
 
-def process_pmgfile(f):
+def process_pmgfile(f, filename):
+    linenr = 0
     global current_pattern
+    global current_subpattern
     while True:
+        linenr += 1
         line = f.readline()
         if line == "": break
         line = line.strip()
@@ -119,19 +125,41 @@ def process_pmgfile(f):
             if current_pattern is not None:
                 block = dict()
                 block["type"] = "final"
-                block["pattern"] = current_pattern
+                block["pattern"] = (current_pattern, current_subpattern)
                 blocks.append(block)
             line = line.split()
             assert len(line) == 2
             assert line[1] not in patterns
             current_pattern = line[1]
+            current_subpattern = ""
             patterns[current_pattern] = len(blocks)
+            subpatterns[(current_pattern, current_subpattern)] = len(blocks)
+            subpattern_args[(current_pattern, current_subpattern)] = list()
             state_types[current_pattern] = dict()
             udata_types[current_pattern] = dict()
             continue
 
         assert current_pattern is not None
 
+        if cmd == "subpattern":
+            block = dict()
+            block["type"] = "final"
+            block["pattern"] = (current_pattern, current_subpattern)
+            blocks.append(block)
+            line = line.split()
+            assert len(line) == 2
+            current_subpattern = line[1]
+            subpattern_args[(current_pattern, current_subpattern)] = list()
+            assert (current_pattern, current_subpattern) not in subpatterns
+            subpatterns[(current_pattern, current_subpattern)] = len(blocks)
+            continue
+
+        if cmd == "arg":
+            line = line.split()
+            assert len(line) > 1
+            subpattern_args[(current_pattern, current_subpattern)] += line[1:]
+            continue
+
         if cmd == "state":
             m = re.match(r"^state\s+<(.*?)>\s+(([A-Za-z_][A-Za-z_0-9]*\s+)*[A-Za-z_][A-Za-z_0-9]*)\s*$", line)
             assert m
@@ -155,11 +183,12 @@ def process_pmgfile(f):
         if cmd == "match":
             block = dict()
             block["type"] = "match"
-            block["pattern"] = current_pattern
+            block["src"] = "%s:%d" % (filename, linenr)
+            block["pattern"] = (current_pattern, current_subpattern)
 
             line = line.split()
             assert len(line) == 2
-            assert line[1] not in state_types[current_pattern]
+            assert (line[1] not in state_types[current_pattern]) or (state_types[current_pattern][line[1]] == "Cell*")
             block["cell"] = line[1]
             state_types[current_pattern][line[1]] = "Cell*";
 
@@ -168,8 +197,10 @@ def process_pmgfile(f):
             block["index"] = list()
             block["filter"] = list()
             block["optional"] = False
+            block["semioptional"] = False
 
             while True:
+                linenr += 1
                 l = f.readline()
                 assert l != ""
                 a = l.split()
@@ -201,31 +232,47 @@ def process_pmgfile(f):
                     block["optional"] = True
                     continue
 
+                if a[0] == "semioptional":
+                    block["semioptional"] = True
+                    continue
+
                 assert False
 
+            if block["optional"]:
+                assert not block["semioptional"]
+
             blocks.append(block)
             continue
 
         if cmd == "code":
             block = dict()
             block["type"] = "code"
-            block["pattern"] = current_pattern
+            block["src"] = "%s:%d" % (filename, linenr)
+            block["pattern"] = (current_pattern, current_subpattern)
 
             block["code"] = list()
+            block["fcode"] = list()
             block["states"] = set()
 
             for s in line.split()[1:]:
                 assert s in state_types[current_pattern]
                 block["states"].add(s)
 
+            codetype = "code"
+
             while True:
+                linenr += 1
                 l = f.readline()
                 assert l != ""
                 a = l.split()
                 if len(a) == 0: continue
                 if a[0] == "endcode": break
 
-                block["code"].append(rewrite_cpp(l.rstrip()))
+                if a[0] == "finally":
+                    codetype = "fcode"
+                    continue
+
+                block[codetype].append(rewrite_cpp(l.rstrip()))
 
             blocks.append(block)
             continue
@@ -234,15 +281,16 @@ def process_pmgfile(f):
 
 for fn in pmgfiles:
     with open(fn, "r") as f:
-        process_pmgfile(f)
+        process_pmgfile(f, fn)
 
 if current_pattern is not None:
     block = dict()
     block["type"] = "final"
-    block["pattern"] = current_pattern
+    block["pattern"] = (current_pattern, current_subpattern)
     blocks.append(block)
 
 current_pattern = None
+current_subpattern = None
 
 if debug:
     pp.pprint(blocks)
@@ -335,7 +383,7 @@ with open(outfile, "w") as f:
         print("    blacklist_dirty = false;", file=f)
         for index in range(len(blocks)):
             block = blocks[index]
-            if block["pattern"] != current_pattern:
+            if block["pattern"] != (current_pattern, current_subpattern):
                 continue
             if block["type"] == "match":
                 print("    if (st_{}.{} != nullptr && blacklist_cells.count(st_{}.{})) {{".format(current_pattern, block["cell"], current_pattern, block["cell"]), file=f)
@@ -429,13 +477,21 @@ with open(outfile, "w") as f:
         print("    run_{}([](){{}});".format(current_pattern, current_pattern), file=f)
         print("  }", file=f)
         print("", file=f)
+
+    for p, s in sorted(subpatterns.keys()):
+        print("  void block_subpattern_{}_{}() {{ block_{}(); }}".format(p, s, subpatterns[(p, s)]), file=f)
+
     current_pattern = None
+    current_subpattern = None
 
     for index in range(len(blocks)):
         block = blocks[index]
 
+        if block["type"] in ("match", "code"):
+            print("  // {}".format(block["src"]), file=f)
+
         print("  void block_{}() {{".format(index), file=f)
-        current_pattern = block["pattern"]
+        current_pattern, current_subpattern = block["pattern"]
 
         if block["type"] == "final":
             print("    on_accept();", file=f)
@@ -449,7 +505,10 @@ with open(outfile, "w") as f:
         nonconst_st = set()
         restore_st = set()
 
-        for i in range(patterns[current_pattern], index):
+        for s in subpattern_args[(current_pattern, current_subpattern)]:
+            const_st.add(s)
+
+        for i in range(subpatterns[(current_pattern, current_subpattern)], index):
             if blocks[i]["type"] == "code":
                 for s in blocks[i]["states"]:
                     const_st.add(s)
@@ -482,6 +541,10 @@ with open(outfile, "w") as f:
             t = state_types[current_pattern][s]
             print("    {} &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)
 
+        for u in sorted(udata_types[current_pattern].keys()):
+            t = udata_types[current_pattern][u]
+            print("    {} &{} YS_ATTRIBUTE(unused) = ud_{}.{};".format(t, u, current_pattern, u), file=f)
+
         if len(restore_st):
             print("", file=f)
             for s in sorted(restore_st):
@@ -490,24 +553,32 @@ with open(outfile, "w") as f:
 
         if block["type"] == "code":
             print("", file=f)
-            print("    do {", file=f)
             print("#define reject do {{ check_blacklist_{}(); goto rollback_label; }} while(0)".format(current_pattern), file=f)
             print("#define accept do {{ on_accept(); check_blacklist_{}(); if (rollback) goto rollback_label; }} while(0)".format(current_pattern), file=f)
             print("#define branch do {{ block_{}(); if (rollback) goto rollback_label; }} while(0)".format(index+1), file=f)
+            print("#define subpattern(pattern_name) block_subpattern_{}_ ## pattern_name ()".format(current_pattern), file=f)
 
             for line in block["code"]:
-                print("    " + line, file=f)
+                print("  " + line, file=f)
 
             print("", file=f)
-            print("      block_{}();".format(index+1), file=f)
+            print("    block_{}();".format(index+1), file=f)
+
             print("#undef reject", file=f)
             print("#undef accept", file=f)
             print("#undef branch", file=f)
-            print("    } while (0);", file=f)
+            print("#undef subpattern", file=f)
+
             print("", file=f)
             print("rollback_label:", file=f)
             print("    YS_ATTRIBUTE(unused);", file=f)
 
+            if len(block["fcode"]):
+                print("#define accept do {{ on_accept(); check_blacklist_{}(); }} while(0)".format(current_pattern), file=f)
+                for line in block["fcode"]:
+                    print("  " + line, file=f)
+                print("#undef accept", file=f)
+
             if len(restore_st) or len(nonconst_st):
                 print("", file=f)
                 for s in sorted(restore_st):
@@ -524,12 +595,15 @@ with open(outfile, "w") as f:
         elif block["type"] == "match":
             assert len(restore_st) == 0
 
+            print("    Cell* backup_{} = {};".format(block["cell"], block["cell"]), file=f)
+
             if len(block["if"]):
                 for expr in block["if"]:
                     print("", file=f)
                     print("    if (!({})) {{".format(expr), file=f)
                     print("      {} = nullptr;".format(block["cell"]), file=f)
                     print("      block_{}();".format(index+1), file=f)
+                    print("      {} = backup_{};".format(block["cell"], block["cell"]), file=f)
                     print("      return;", file=f)
                     print("    }", file=f)
 
@@ -539,16 +613,21 @@ with open(outfile, "w") as f:
                 print("    std::get<{}>(key) = {};".format(field, entry[2]), file=f)
             print("    const vector<Cell*> &cells = index_{}[key];".format(index), file=f)
 
+            if block["semioptional"]:
+                print("    bool found_semioptional_match = false;", file=f)
+
             print("", file=f)
             print("    for (int idx = 0; idx < GetSize(cells); idx++) {", file=f)
             print("      {} = cells[idx];".format(block["cell"]), file=f)
             print("      if (blacklist_cells.count({})) continue;".format(block["cell"]), file=f)
             for expr in block["filter"]:
                 print("      if (!({})) continue;".format(expr), file=f)
+            if block["semioptional"]:
+                print("      found_semioptional_match = true;", file=f)
             print("      block_{}();".format(index+1), file=f)
             print("      if (rollback) {", file=f)
             print("        if (rollback != {}) {{".format(index+1), file=f)
-            print("          {} = nullptr;".format(block["cell"]), file=f)
+            print("          {} = backup_{};".format(block["cell"], block["cell"]), file=f)
             print("          return;", file=f)
             print("        }", file=f)
             print("        rollback = 0;", file=f)
@@ -561,6 +640,11 @@ with open(outfile, "w") as f:
             if block["optional"]:
                 print("    block_{}();".format(index+1), file=f)
 
+            if block["semioptional"]:
+                print("    if (!found_semioptional_match) block_{}();".format(index+1), file=f)
+
+            print("    {} = backup_{};".format(block["cell"], block["cell"]), file=f)
+
         else:
             assert False