gen/fhdl/verilog: list available clock domains on keyerror
[litex.git] / litex / gen / fhdl / verilog.py
index a0fc3b165fbaaf5b53ef779705910e56606798d3..d5aef8d6d4f6ad0fbddf7cd20d9b95484ad272bc 100644 (file)
@@ -5,7 +5,6 @@ import collections
 from litex.gen.fhdl.structure import *
 from litex.gen.fhdl.structure import _Operator, _Slice, _Assign, _Fragment
 from litex.gen.fhdl.tools import *
-from litex.gen.fhdl.bitcontainer import bits_for
 from litex.gen.fhdl.namer import build_namespace
 from litex.gen.fhdl.conv_output import ConvOutput
 
@@ -28,7 +27,7 @@ _reserved_keywords = {
     "specify", "specparam", "strong0", "strong1", "supply0", "supply1",
     "table", "task", "time", "tran", "tranif0", "tranif1", "tri", "tri0",
     "tri1", "triand", "trior", "trireg", "unsigned", "use", "vectored", "wait",
-    "wand", "weak0", "weak1", "while", "wire", "wor","xnor", "xor"
+    "wand", "weak0", "weak1", "while", "wire", "wor","xnor", "xor", "do"
 }
 
 
@@ -116,11 +115,16 @@ def _printexpr(ns, node):
 (_AT_BLOCKING, _AT_NONBLOCKING, _AT_SIGNAL) = range(3)
 
 
-def _printnode(ns, at, level, node, target_filter=None):
-    if node is None:
-        return ""
-    elif target_filter is not None and target_filter not in list_targets(node):
-        return ""
+def _printnode(ns, at, level, node):
+    if isinstance(node, Display):
+        s = "\"" + node.s + "\\r\""
+        for arg in node.args:
+            s += ", "
+            if isinstance(arg, Signal):
+                s += ns.get_name(arg)
+            else:
+                s += str(arg)
+        return "\t"*level + "$display(" + s + ");\n"
     elif isinstance(node, _Assign):
         if at == _AT_BLOCKING:
             assignment = " = "
@@ -132,26 +136,27 @@ def _printnode(ns, at, level, node, target_filter=None):
             assignment = " <= "
         return "\t"*level + _printexpr(ns, node.l)[0] + assignment + _printexpr(ns, node.r)[0] + ";\n"
     elif isinstance(node, collections.Iterable):
-        return "".join(_printnode(ns, at, level, n, target_filter) for n in node)
+        return "".join(list(map(partial(_printnode, ns, at, level), node)))
     elif isinstance(node, If):
         r = "\t"*level + "if (" + _printexpr(ns, node.cond)[0] + ") begin\n"
-        r += _printnode(ns, at, level + 1, node.t, target_filter)
+        r += _printnode(ns, at, level + 1, node.t)
         if node.f:
             r += "\t"*level + "end else begin\n"
-            r += _printnode(ns, at, level + 1, node.f, target_filter)
+            r += _printnode(ns, at, level + 1, node.f)
         r += "\t"*level + "end\n"
         return r
     elif isinstance(node, Case):
         if node.cases:
             r = "\t"*level + "case (" + _printexpr(ns, node.test)[0] + ")\n"
-            css = sorted([(k, v) for (k, v) in node.cases.items() if k != "default"], key=itemgetter(0))
+            css = [(k, v) for k, v in node.cases.items() if isinstance(k, Constant)]
+            css = sorted(css, key=lambda x: x[0].value)
             for choice, statements in css:
                 r += "\t"*(level + 1) + _printexpr(ns, choice)[0] + ": begin\n"
-                r += _printnode(ns, at, level + 2, statements, target_filter)
+                r += _printnode(ns, at, level + 2, statements)
                 r += "\t"*(level + 1) + "end\n"
             if "default" in node.cases:
                 r += "\t"*(level + 1) + "default: begin\n"
-                r += _printnode(ns, at, level + 2, node.cases["default"], target_filter)
+                r += _printnode(ns, at, level + 2, node.cases["default"])
                 r += "\t"*(level + 1) + "end\n"
             r += "\t"*level + "endcase\n"
             return r
@@ -169,8 +174,30 @@ def _list_comb_wires(f):
             r |= g[0]
     return r
 
+def _printattr(sig, attr_translate):
+    r = ""
+    firsta = True
+    for attr in sorted(sig.attr,
+                       key=lambda x: ("", x) if isinstance(x, str) else x):
+        if isinstance(attr, tuple):
+            # platform-dependent attribute
+            attr_name, attr_value = attr
+        else:
+            # translated attribute
+            at = attr_translate[attr]
+            if at is None:
+                continue
+            attr_name, attr_value = at
+        if not firsta:
+            r += ", "
+        firsta = False
+        r += attr_name + " = \"" + attr_value + "\""
+    if r:
+        r = "(* " + r + " *)"
+    return r
+
 
-def _printheader(f, ios, name, ns,
+def _printheader(f, ios, name, ns, attr_translate,
                  reg_initialization):
     sigs = list_signals(f) | list_special_ios(f, True, True, True)
     special_outs = list_special_ios(f, False, True, True)
@@ -183,6 +210,9 @@ def _printheader(f, ios, name, ns,
         if not firstp:
             r += ",\n"
         firstp = False
+        attr = _printattr(sig, attr_translate)
+        if attr:
+            r += "\t" + attr
         if sig in inouts:
             r += "\tinout " + _printsig(ns, sig)
         elif sig in targets:
@@ -194,6 +224,9 @@ def _printheader(f, ios, name, ns,
             r += "\tinput " + _printsig(ns, sig)
     r += "\n);\n\n"
     for sig in sorted(sigs - ios, key=lambda x: x.duid):
+        attr = _printattr(sig, attr_translate)
+        if attr:
+            r += attr + " "
         if sig in wires:
             r += "wire " + _printsig(ns, sig) + ";\n"
         else:
@@ -212,35 +245,25 @@ def _printcomb(f, ns,
     r = ""
     if f.comb:
         if dummy_signal:
-            # Generate a dummy event to get the simulator
-            # to run the combinatorial process once at the beginning.
+            explanation = """
+// Adding a dummy event (using a dummy signal 'dummy_s') to get the simulator
+// to run the combinatorial process once at the beginning.
+"""
             syn_off = "// synthesis translate_off\n"
             syn_on = "// synthesis translate_on\n"
             dummy_s = Signal(name_override="dummy_s")
+            r += explanation
             r += syn_off
             r += "reg " + _printsig(ns, dummy_s) + ";\n"
             r += "initial " + ns.get_name(dummy_s) + " <= 1'd0;\n"
             r += syn_on
-
-       
-        from collections import defaultdict
-
-        target_stmt_map = defaultdict(list)
-
-        for statement in flat_iteration(f.comb):
-            targets = list_targets(statement)
-            for t in targets:
-                target_stmt_map[t].append(statement)
-
-        #from pprint import pprint
-        #pprint(target_stmt_map)
+            r += "\n"
 
         groups = group_by_targets(f.comb)
 
-        for n, (t, stmts) in enumerate(target_stmt_map.items()):
-            assert isinstance(t, Signal)
-            if len(stmts) == 1 and isinstance(stmts[0], _Assign):
-                r += "assign " + _printnode(ns, _AT_BLOCKING, 0, stmts[0])
+        for n, g in enumerate(groups):
+            if len(g[1]) == 1 and isinstance(g[1][0], _Assign):
+                r += "assign " + _printnode(ns, _AT_BLOCKING, 0, g[1][0])
             else:
                 if dummy_signal:
                     dummy_d = Signal(name_override="dummy_d")
@@ -252,15 +275,17 @@ def _printcomb(f, ns,
                 if display_run:
                     r += "\t$display(\"Running comb block #" + str(n) + "\");\n"
                 if blocking_assign:
-                       r += "\t" + ns.get_name(t) + " = " + _printexpr(ns, t.reset)[0] + ";\n"
-                       r += _printnode(ns, _AT_BLOCKING, 1, stmts, t)
+                    for t in g[0]:
+                        r += "\t" + ns.get_name(t) + " = " + _printexpr(ns, t.reset)[0] + ";\n"
+                    r += _printnode(ns, _AT_BLOCKING, 1, g[1])
                 else:
-                       r += "\t" + ns.get_name(t) + " <= " + _printexpr(ns, t.reset)[0] + ";\n"
-                       r += _printnode(ns, _AT_NONBLOCKING, 1, stmts, t)
+                    for t in g[0]:
+                        r += "\t" + ns.get_name(t) + " <= " + _printexpr(ns, t.reset)[0] + ";\n"
+                    r += _printnode(ns, _AT_NONBLOCKING, 1, g[1])
                 if dummy_signal:
-                       r += syn_off
-                       r += "\t" + ns.get_name(dummy_d) + " = " + ns.get_name(dummy_s) + ";\n"
-                       r += syn_on
+                    r += syn_off
+                    r += "\t" + ns.get_name(dummy_d) + " <= " + ns.get_name(dummy_s) + ";\n"
+                    r += syn_on
                 r += "end\n"
     r += "\n"
     return r
@@ -275,61 +300,30 @@ def _printsync(f, ns):
     return r
 
 
-def _call_special_classmethod(overrides, obj, method, *args, **kwargs):
-    cl = obj.__class__
-    if cl in overrides:
-        cl = overrides[cl]
-    if hasattr(cl, method):
-        return getattr(cl, method)(obj, *args, **kwargs)
-    else:
-        return None
-
-
-def _lower_specials_step(overrides, specials):
-    f = _Fragment()
-    lowered_specials = set()
-    for special in sorted(specials, key=lambda x: x.duid):
-        impl = _call_special_classmethod(overrides, special, "lower")
-        if impl is not None:
-            f += impl.get_fragment()
-            lowered_specials.add(special)
-    return f, lowered_specials
-
-
-def _can_lower(overrides, specials):
-    for special in specials:
-        cl = special.__class__
-        if cl in overrides:
-            cl = overrides[cl]
-        if hasattr(cl, "lower"):
-            return True
-    return False
-
-
-def _lower_specials(overrides, specials):
-    f, lowered_specials = _lower_specials_step(overrides, specials)
-    while _can_lower(overrides, f.specials):
-        f2, lowered_specials2 = _lower_specials_step(overrides, f.specials)
-        f += f2
-        lowered_specials |= lowered_specials2
-        f.specials -= lowered_specials2
-    return f, lowered_specials
-
-
 def _printspecials(overrides, specials, ns, add_data_file):
     r = ""
     for special in sorted(specials, key=lambda x: x.duid):
-        pr = _call_special_classmethod(overrides, special, "emit_verilog", ns, add_data_file)
+        pr = call_special_classmethod(overrides, special, "emit_verilog", ns, add_data_file)
         if pr is None:
             raise NotImplementedError("Special " + str(special) + " failed to implement emit_verilog")
         r += pr
     return r
 
 
+class DummyAttrTranslate:
+    def __getitem__(self, k):
+        return (k, "true")
+
+
 def convert(f, ios=None, name="top",
   special_overrides=dict(),
+  attr_translate=DummyAttrTranslate(),
   create_clock_domains=True,
-  display_run=False, asic_syntax=False):
+  display_run=False,
+  reg_initialization=True,
+  dummy_signal=True,
+  blocking_assign=False,
+  regular_comb=True):
     r = ConvOutput()
     if not isinstance(f, _Fragment):
         f = f.get_fragment()
@@ -345,27 +339,35 @@ def convert(f, ios=None, name="top",
                 f.clock_domains.append(cd)
                 ios |= {cd.clk, cd.rst}
             else:
+                print("available clock domains:")
+                for f in f.clock_domains:
+                    print(f.name)
                 raise KeyError("Unresolved clock domain: '"+cd_name+"'")
 
     f = lower_complex_slices(f)
     insert_resets(f)
     f = lower_basics(f)
-    fs, lowered_specials = _lower_specials(special_overrides, f.specials)
+    fs, lowered_specials = lower_specials(special_overrides, f.specials)
     f += lower_basics(fs)
 
+    for io in sorted(ios, key=lambda x: x.duid):
+        if io.name_override is None:
+            io_name = io.backtrace[-1][0]
+            if io_name:
+                io.name_override = io_name
     ns = build_namespace(list_signals(f) \
         | list_special_ios(f, True, True, True) \
         | ios, _reserved_keywords)
     ns.clock_domains = f.clock_domains
     r.ns = ns
 
-    src = "/* Machine-generated using Migen */\n"
-    src += _printheader(f, ios, name, ns,
-                        reg_initialization=not asic_syntax)
+    src = "/* Machine-generated using LiteX gen */\n"
+    src += _printheader(f, ios, name, ns, attr_translate,
+                        reg_initialization=reg_initialization)
     src += _printcomb(f, ns,
                       display_run=display_run,
-                      dummy_signal=not asic_syntax,
-                      blocking_assign=asic_syntax)
+                      dummy_signal=dummy_signal,
+                      blocking_assign=blocking_assign)
     src += _printsync(f, ns)
     src += _printspecials(special_overrides, f.specials - lowered_specials, ns, r.add_data_file)
     src += "endmodule\n"