verilog: split comb block, use assign statements
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Sat, 7 Jan 2012 11:19:06 +0000 (12:19 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Sat, 7 Jan 2012 11:19:06 +0000 (12:19 +0100)
migen/fhdl/tools.py
migen/fhdl/verilog.py

index 0202f4f34705500505ae9404f76ce6922d587622..870e8f83e71a90835fd4cfcbd5bf03730a1534b9 100644 (file)
@@ -63,8 +63,6 @@ def list_targets(node):
        elif isinstance(node, Cat):
                l = list(map(list_targets, node.l))
                return set().union(*l)
-       elif isinstance(node, Replicate):
-               return list_targets(node.v)
        elif isinstance(node, _Assign):
                return list_targets(node.l)
        elif isinstance(node, _StatementList):
@@ -80,6 +78,21 @@ def list_targets(node):
        else:
                raise TypeError
 
+def group_by_targets(sl):
+       groups = []
+       for statement in sl.l:
+               targets = list_targets(statement)
+               processed = False
+               for g in groups:
+                       if not targets.isdisjoint(g[0]):
+                               g[0].update(targets)
+                               g[1].append(statement)
+                               processed = True
+                               break
+               if not processed:
+                       groups.append((targets, [statement]))
+       return groups
+
 def list_inst_outs(i):
        if isinstance(i, Fragment):
                return list_inst_outs(i.instances)
index b98f3ea8dedbba62fa3aed5c6e887cccf0fc7337..8161037abef06ce0e6a56f36421ff327c992e598 100644 (file)
@@ -46,42 +46,57 @@ def _printexpr(ns, node):
        else:
                raise TypeError
 
-def _printnode(ns, is_sync, level, node):
+(_AT_BLOCKING, _AT_NONBLOCKING, _AT_SIGNAL) = range(3)
+
+def _printnode(ns, at, level, node):
        if isinstance(node, _Assign):
-               if is_sync and is_variable(node.l):
+               if at == _AT_BLOCKING:
+                       assignment = " = "
+               elif at == _AT_NONBLOCKING:
+                       assignment = " <= "
+               elif is_variable(node.l):
                        assignment = " = "
                else:
                        assignment = " <= "
                return "\t"*level + _printexpr(ns, node.l) + assignment + _printexpr(ns, node.r) + ";\n"
        elif isinstance(node, _StatementList):
-               return "".join(list(map(partial(_printnode, ns, is_sync, level), node.l)))
+               return "".join(list(map(partial(_printnode, ns, at, level), node.l)))
        elif isinstance(node, If):
                r = "\t"*level + "if (" + _printexpr(ns, node.cond) + ") begin\n"
-               r += _printnode(ns, is_sync, level + 1, node.t)
+               r += _printnode(ns, at, level + 1, node.t)
                if node.f.l:
                        r += "\t"*level + "end else begin\n"
-                       r += _printnode(ns, is_sync, level + 1, node.f)
+                       r += _printnode(ns, at, level + 1, node.f)
                r += "\t"*level + "end\n"
                return r
        elif isinstance(node, Case):
                r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n"
                for case in node.cases:
                        r += "\t"*(level + 1) + _printexpr(ns, case[0]) + ": begin\n"
-                       r += _printnode(ns, is_sync, level + 2, case[1])
+                       r += _printnode(ns, at, level + 2, case[1])
                        r += "\t"*(level + 1) + "end\n"
                if node.default.l:
                        r += "\t"*(level + 1) + "default: begin\n"
-                       r += _printnode(ns, is_sync, level + 2, node.default)
+                       r += _printnode(ns, at, level + 2, node.default)
                        r += "\t"*(level + 1) + "end\n"
                r += "\t"*level + "endcase\n"
                return r
        else:
                raise TypeError
 
+def _list_comb_wires(f):
+       r = set()
+       groups = group_by_targets(f.comb)
+       for g in groups:
+               if len(g[1]) == 1 and isinstance(g[1][0], _Assign):
+                       r |= g[0]
+       return r
+
 def _printheader(f, ios, name, ns):
        sigs = list_signals(f)
        targets = list_targets(f)
        instouts = list_inst_outs(f)
+       wires = _list_comb_wires(f)
        r = "module " + name + "(\n"
        firstp = True
        for sig in ios:
@@ -89,14 +104,17 @@ def _printheader(f, ios, name, ns):
                        r += ",\n"
                firstp = False
                if sig in targets:
-                       r += "\toutput reg " + _printsig(ns, sig)
+                       if sig in wires:
+                               r += "\toutput " + _printsig(ns, sig)
+                       else:
+                               r += "\toutput reg " + _printsig(ns, sig)
                elif sig in instouts:
                        r += "\toutput " + _printsig(ns, sig)
                else:
                        r += "\tinput " + _printsig(ns, sig)
        r += "\n);\n\n"
        for sig in sigs - ios:
-               if sig in instouts:
+               if sig in wires or sig in instouts:
                        r += "wire " + _printsig(ns, sig) + ";\n"
                else:
                        r += "reg " + _printsig(ns, sig) + ";\n"
@@ -110,37 +128,39 @@ def _printcomb(f, ns):
                # to run the combinatorial process once at the beginning.
                syn_off = "// synthesis translate off\n"
                syn_on = "// synthesis translate on\n"
-               dummy_s = Signal(name="dummy_s")
-               dummy_d = Signal(name="dummy_d")
+               dummy_s = Signal()
                r += syn_off
                r += "reg " + _printsig(ns, dummy_s) + ";\n"
-               r += "reg " + _printsig(ns, dummy_d) + ";\n"
                r += "initial " + ns.get_name(dummy_s) + " <= 1'b0;\n"
-               r += syn_on + "\n"
-               
-               r += "always @(*) begin\n"
-               to_reset = list_targets(f.comb)
-               # do not reset signals with obvious unconditional assignments
-               for s in f.comb.l:
-                       if isinstance(s, _Assign) and isinstance(s.l, Signal):
-                               try:
-                                       to_reset.remove(s.l)
-                               except KeyError:
-                                       pass
-               for t in to_reset:
-                       r += "\t" + ns.get_name(t) + " <= " + str(t.reset) + ";\n"
-               r += _printnode(ns, False, 1, f.comb)
-               r += syn_off
-               r += "\t" + ns.get_name(dummy_d) + " <= " + ns.get_name(dummy_s) + ";\n"
                r += syn_on
-               r += "end\n\n"
+               
+               groups = group_by_targets(f.comb)
+               
+               for g in groups:
+                       if len(g[1]) == 1 and isinstance(g[1][0], _Assign):
+                               r += "assign " + _printnode(ns, _AT_BLOCKING, 0, g[1][0])
+                       else:
+                               dummy_d = Signal()
+                               r += "\n" + syn_off
+                               r += "reg " + _printsig(ns, dummy_d) + ";\n"
+                               r += syn_on
+                               
+                               r += "always @(*) begin\n"
+                               for t in g[0]:
+                                       r += "\t" + ns.get_name(t) + " <= " + str(t.reset) + ";\n"
+                               r += _printnode(ns, _AT_NONBLOCKING, 1, _StatementList(g[1]))
+                               r += syn_off
+                               r += "\t" + ns.get_name(dummy_d) + " <= " + ns.get_name(dummy_s) + ";\n"
+                               r += syn_on
+                               r += "end\n"
+       r += "\n"
        return r
 
 def _printsync(f, ns, clk_signal, rst_signal):
        r = ""
        if f.sync.l:
                r += "always @(posedge " + ns.get_name(clk_signal) + ") begin\n"
-               r += _printnode(ns, True, 1, insert_reset(rst_signal, f.sync))
+               r += _printnode(ns, _AT_SIGNAL, 1, insert_reset(rst_signal, f.sync))
                r += "end\n\n"
        return r