Add pmgen "generate" feature
authorClifford Wolf <clifford@clifford.at>
Fri, 16 Aug 2019 11:26:36 +0000 (13:26 +0200)
committerClifford Wolf <clifford@clifford.at>
Fri, 16 Aug 2019 11:26:36 +0000 (13:26 +0200)
Signed-off-by: Clifford Wolf <clifford@clifford.at>
passes/pmgen/pmgen.py
passes/pmgen/test_pmgen.cc
passes/pmgen/test_pmgen.pmg

index c6e9a9cbfc2735dc6234819cae6c0efc2e3109d5..e252c52f502f81987edb12d19225835dfd863c34 100644 (file)
@@ -186,6 +186,9 @@ def process_pmgfile(f, filename):
             block["src"] = "%s:%d" % (filename, linenr)
             block["pattern"] = (current_pattern, current_subpattern)
 
+            block["genargs"] = None
+            block["gencode"] = None
+
             line = line.split()
             assert len(line) == 2
             assert (line[1] not in state_types[current_pattern]) or (state_types[current_pattern][line[1]] == "Cell*")
@@ -236,6 +239,19 @@ def process_pmgfile(f, filename):
                     block["semioptional"] = True
                     continue
 
+                if a[0] == "generate":
+                    block["genargs"] = list([int(s) for s in a[1:]])
+                    block["gencode"] = list()
+                    assert len(block["genargs"]) < 2
+                    while True:
+                        linenr += 1
+                        l = f.readline()
+                        assert l != ""
+                        a = l.split()
+                        if a[0] == "endmatch": break
+                        block["gencode"].append(rewrite_cpp(l.rstrip()))
+                    break
+
                 assert False
 
             if block["optional"]:
@@ -310,7 +326,17 @@ with open(outfile, "w") as f:
     print("struct {}_pm {{".format(prefix), file=f)
     print("  Module *module;", file=f)
     print("  SigMap sigmap;", file=f)
-    print("  std::function<void()> on_accept;".format(prefix), file=f)
+    print("  std::function<void()> on_accept;", file=f)
+    print("  bool generate_mode;", file=f)
+    print("", file=f)
+
+    print("  uint32_t rngseed;", file=f)
+    print("  int rng(unsigned int n) {", file=f)
+    print("    rngseed ^= rngseed << 13;", file=f)
+    print("    rngseed ^= rngseed >> 17;", file=f)
+    print("    rngseed ^= rngseed << 5;", file=f)
+    print("    return rngseed % n;", file=f)
+    print("  }", file=f)
     print("", file=f)
 
     for index in range(len(blocks)):
@@ -415,7 +441,7 @@ with open(outfile, "w") as f:
     print("", file=f)
 
     print("  {}_pm(Module *module, const vector<Cell*> &cells) :".format(prefix), file=f)
-    print("      module(module), sigmap(module) {", file=f)
+    print("      module(module), sigmap(module), generate_mode(false), rngseed(12345678) {", file=f)
     for current_pattern in sorted(patterns.keys()):
         for s, t in sorted(udata_types[current_pattern].items()):
             if t.endswith("*"):
@@ -469,17 +495,15 @@ with open(outfile, "w") as f:
         print("    run_{}([&](){{on_accept_f(*this);}});".format(current_pattern), file=f)
         print("  }", file=f)
         print("", file=f)
-        print("  void run_{}(std::function<void(state_{}_t&)> on_accept_f) {{".format(current_pattern, current_pattern), file=f)
-        print("    run_{}([&](){{on_accept_f(st_{});}});".format(current_pattern, current_pattern), file=f)
-        print("  }", file=f)
-        print("", file=f)
         print("  void run_{}() {{".format(current_pattern), file=f)
         print("    run_{}([](){{}});".format(current_pattern, current_pattern), file=f)
         print("  }", file=f)
         print("", file=f)
 
-    for p, s in sorted(subpatterns.keys()):
-        print("  void block_subpattern_{}_{}() {{ block_{}(); }}".format(p, s, subpatterns[(p, s)]), file=f)
+    if len(subpatterns):
+        for p, s in sorted(subpatterns.keys()):
+            print("  void block_subpattern_{}_{}() {{ block_{}(); }}".format(p, s, subpatterns[(p, s)]), file=f)
+        print("", file=f)
 
     current_pattern = None
     current_subpattern = None
@@ -611,8 +635,8 @@ with open(outfile, "w") as f:
                 print("    std::get<{}>(key) = {};".format(field, entry[2]), file=f)
             print("    const vector<Cell*> &cells = index_{}[key];".format(index), file=f)
 
-            if block["semioptional"]:
-                print("    bool found_semioptional_match = false;", file=f)
+            if block["semioptional"] or block["genargs"] is not None:
+                print("    bool found_any_match = false;", file=f)
 
             print("", file=f)
             print("    for (int idx = 0; idx < GetSize(cells); idx++) {", file=f)
@@ -620,8 +644,8 @@ with open(outfile, "w") as f:
             print("      if (blacklist_cells.count({})) continue;".format(block["cell"]), file=f)
             for expr in block["filter"]:
                 print("      if (!({})) continue;".format(expr), file=f)
-            if block["semioptional"]:
-                print("      found_semioptional_match = true;", file=f)
+            if block["semioptional"] or block["genargs"] is not None:
+                print("      found_any_match = true;", file=f)
             print("      block_{}();".format(index+1), file=f)
             print("      if (rollback) {", file=f)
             print("        if (rollback != {}) {{".format(index+1), file=f)
@@ -639,10 +663,17 @@ with open(outfile, "w") as f:
                 print("    block_{}();".format(index+1), file=f)
 
             if block["semioptional"]:
-                print("    if (!found_semioptional_match) block_{}();".format(index+1), file=f)
+                print("    if (!found_any_match) block_{}();".format(index+1), file=f)
 
             print("    {} = backup_{};".format(block["cell"], block["cell"]), file=f)
 
+            if block["genargs"] is not None:
+                print("    if (generate_mode && !found_any_match) {", file=f)
+                if len(block["genargs"]) == 1:
+                    print("    if (rng(100) >= {}) return;".format(block["genargs"][0]), file=f)
+                for line in block["gencode"]:
+                    print("      " + line, file=f)
+                print("    }", file=f)
         else:
             assert False
 
index 3bff9ae12736eca5b668b4238774d1bb553495e3..48d58b263563581b71ded6a9807270f8352ab7fb 100644 (file)
 USING_YOSYS_NAMESPACE
 PRIVATE_NAMESPACE_BEGIN
 
+// for peepopt_pm
+bool did_something;
+
 #include "passes/pmgen/test_pmgen_pm.h"
+#include "passes/pmgen/ice40_dsp_pm.h"
+#include "passes/pmgen/peepopt_pm.h"
 
 void reduce_chain(test_pmgen_pm &pm)
 {
@@ -94,6 +99,100 @@ void reduce_tree(test_pmgen_pm &pm)
        log("    -> %s (%s)\n", log_id(c), log_id(c->type));
 }
 
+#define GENERATE_PATTERN(pmclass, pattern) \
+       generate_pattern<pmclass>([](pmclass &pm, std::function<void()> f){ return pm.run_ ## pattern(f); }, #pmclass, #pattern, design)
+
+void pmtest_addports(Module *module)
+{
+       pool<SigBit> driven_bits, used_bits;
+       SigMap sigmap(module);
+       int icnt = 0, ocnt = 0;
+
+       for (auto cell : module->cells())
+       for (auto conn : cell->connections())
+       {
+               if (cell->input(conn.first))
+                       for (auto bit : sigmap(conn.second))
+                               used_bits.insert(bit);
+               if (cell->output(conn.first))
+                       for (auto bit : sigmap(conn.second))
+                               driven_bits.insert(bit);
+       }
+
+       for (auto wire : vector<Wire*>(module->wires()))
+       {
+               SigSpec ibits, obits;
+               for (auto bit : sigmap(wire)) {
+                       if (!used_bits.count(bit))
+                               obits.append(bit);
+                       if (!driven_bits.count(bit))
+                               ibits.append(bit);
+               }
+               if (!ibits.empty()) {
+                       Wire *w = module->addWire(stringf("\\i%d", icnt++), GetSize(ibits));
+                       w->port_input = true;
+                       module->connect(ibits, w);
+               }
+               if (!obits.empty()) {
+                       Wire *w = module->addWire(stringf("\\o%d", ocnt++), GetSize(obits));
+                       w->port_output = true;
+                       module->connect(w, obits);
+               }
+       }
+
+       module->fixup_ports();
+}
+
+template <class pm>
+void generate_pattern(std::function<void(pm&,std::function<void()>)> run, const char *pmclass, const char *pattern, Design *design)
+{
+       log("Generating \"%s\" patterns for pattern matcher \"%s\".\n", pattern, pmclass);
+
+       int modcnt = 0;
+
+       while (modcnt < 100)
+       {
+               int submodcnt = 0, itercnt = 0, cellcnt = 0;
+               Module *mod = design->addModule(NEW_ID);
+
+               while (submodcnt < 10 && itercnt++ < 1000)
+               {
+                       pm matcher(mod, mod->cells());
+
+                       matcher.rng(1);
+                       matcher.rngseed += modcnt;
+                       matcher.rng(1);
+                       matcher.rngseed += submodcnt;
+                       matcher.rng(1);
+                       matcher.rngseed += itercnt;
+                       matcher.rng(1);
+                       matcher.rngseed += cellcnt;
+                       matcher.rng(1);
+
+                       if (GetSize(mod->cells()) != cellcnt)
+                       {
+                               bool found_match = false;
+                               run(matcher, [&](){ found_match = true; });
+
+                               if (found_match) {
+                                       Module *m = design->addModule(stringf("\\pmtest_%s_%s_%05d",
+                                                       pmclass, pattern, modcnt++));
+                                       mod->cloneInto(m);
+                                       pmtest_addports(m);
+                                       submodcnt++;
+                               }
+
+                               cellcnt = GetSize(mod->cells());
+                       }
+
+                       matcher.generate_mode = true;
+                       run(matcher, [](){});
+               }
+
+               design->remove(mod);
+       }
+}
+
 struct TestPmgenPass : public Pass {
        TestPmgenPass() : Pass("test_pmgen", "test pass for pmgen") { }
        void help() YS_OVERRIDE
@@ -104,11 +203,18 @@ struct TestPmgenPass : public Pass {
                log("\n");
                log("Demo for recursive pmgen patterns. Map chains of AND/OR/XOR to $reduce_*.\n");
                log("\n");
+
                log("\n");
                log("    test_pmgen -reduce_tree [options] [selection]\n");
                log("\n");
                log("Demo for recursive pmgen patterns. Map trees of AND/OR/XOR to $reduce_*.\n");
                log("\n");
+
+               log("\n");
+               log("    test_pmgen -generate [options] <pattern_name>\n");
+               log("\n");
+               log("Create modules that match the specified pattern.\n");
+               log("\n");
        }
 
        void execute_reduce_chain(std::vector<std::string> args, RTLIL::Design *design)
@@ -149,6 +255,40 @@ struct TestPmgenPass : public Pass {
                        test_pmgen_pm(module, module->selected_cells()).run_reduce(reduce_tree);
        }
 
+       void execute_generate(std::vector<std::string> args, RTLIL::Design *design)
+       {
+               log_header(design, "Executing TEST_PMGEN pass (-generate).\n");
+
+               size_t argidx;
+               for (argidx = 2; argidx < args.size(); argidx++)
+               {
+                       // if (args[argidx] == "-singleton") {
+                       //      singleton_mode = true;
+                       //      continue;
+                       // }
+                       break;
+               }
+
+               if (argidx+1 != args.size())
+                       log_cmd_error("Expected exactly one pattern.\n");
+
+               string pattern = args[argidx];
+
+               if (pattern == "reduce")
+                       return GENERATE_PATTERN(test_pmgen_pm, reduce);
+
+               if (pattern == "ice40_dsp")
+                       return GENERATE_PATTERN(ice40_dsp_pm, ice40_dsp);
+
+               if (pattern == "peepopt-muldiv")
+                       return GENERATE_PATTERN(peepopt_pm, muldiv);
+
+               if (pattern == "peepopt-shiftmul")
+                       return GENERATE_PATTERN(peepopt_pm, shiftmul);
+
+               log_cmd_error("Unkown pattern: %s\n", pattern.c_str());
+       }
+
        void execute(std::vector<std::string> args, RTLIL::Design *design) YS_OVERRIDE
        {
                if (GetSize(args) > 1)
@@ -157,6 +297,8 @@ struct TestPmgenPass : public Pass {
                                return execute_reduce_chain(args, design);
                        if (args[1] == "-reduce_tree")
                                return execute_reduce_tree(args, design);
+                       if (args[1] == "-generate")
+                               return execute_generate(args, design);
                }
                log_cmd_error("Missing or unsupported mode parameter.\n");
        }
index ccb37e553602d17b57eed2fef1cecec2fd0b1427..077d337d6b3a6699c5baf9c8444ef8047470cae0 100644 (file)
@@ -13,6 +13,22 @@ endcode
 match first
        select first->type.in($_AND_, $_OR_, $_XOR_)
        filter !non_first_cells.count(first)
+generate
+       SigSpec A = module->addWire(NEW_ID);
+       SigSpec B = module->addWire(NEW_ID);
+       SigSpec Y = module->addWire(NEW_ID);
+       switch (rng(3))
+       {
+       case 0:
+               module->addAndGate(NEW_ID, A, B, Y);
+               break;
+       case 1:
+               module->addOrGate(NEW_ID, A, B, Y);
+               break;
+       case 2:
+               module->addXorGate(NEW_ID, A, B, Y);
+               break;
+       }
 endmatch
 
 code
@@ -64,6 +80,12 @@ match next
        select next->type.in($_AND_, $_OR_, $_XOR_)
        index <IdString> next->type === chain.back().first->type
        index <SigSpec> port(next, \Y) === port(chain.back().first, chain.back().second)
+generate 50
+       SigSpec A = module->addWire(NEW_ID);
+       SigSpec B = module->addWire(NEW_ID);
+       SigSpec Y = port(chain.back().first, chain.back().second);
+       Cell *c = module->addAndGate(NEW_ID, A, B, Y);
+       c->type = chain.back().first->type;
 endmatch
 
 code