add TODOs
[litex.git] / litex / gen / fhdl / verilog.py
1 from functools import partial
2 from operator import itemgetter
3 import collections
4
5 from litex.gen.fhdl.structure import *
6 from litex.gen.fhdl.structure import _Operator, _Slice, _Assign, _Fragment
7 from litex.gen.fhdl.tools import *
8 from litex.gen.fhdl.bitcontainer import bits_for
9 from litex.gen.fhdl.namer import build_namespace
10 from litex.gen.fhdl.conv_output import ConvOutput
11
12 # TODO: clean up simulation hack
13
# Verilog reserved words. This set is handed to build_namespace() by
# convert().
#
# Two entries were wrong before: a missing comma silently fused "pull1" and
# "pulldown" into the single string "pull1pulldown" (so neither keyword was
# actually listed), and "rcmos" was misspelled as "remos".
_reserved_keywords = {
    "always", "and", "assign", "automatic", "begin", "buf", "bufif0", "bufif1",
    "case", "casex", "casez", "cell", "cmos", "config", "deassign", "default",
    "defparam", "design", "disable", "edge", "else", "end", "endcase",
    "endconfig", "endfunction", "endgenerate", "endmodule", "endprimitive",
    "endspecify", "endtable", "endtask", "event", "for", "force", "forever",
    "fork", "function", "generate", "genvar", "highz0", "highz1", "if",
    "ifnone", "incdir", "include", "initial", "inout", "input",
    "instance", "integer", "join", "large", "liblist", "library", "localparam",
    "macromodule", "medium", "module", "nand", "negedge", "nmos", "nor",
    "noshowcancelled", "not", "notif0", "notif1", "or", "output", "parameter",
    "pmos", "posedge", "primitive", "pull0", "pull1", "pulldown",
    "pullup", "pulsestyle_onevent", "pulsestyle_ondetect", "rcmos", "real",
    "realtime", "reg", "release", "repeat", "rnmos", "rpmos", "rtran",
    "rtranif0", "rtranif1", "scalared", "showcancelled", "signed", "small",
    "specify", "specparam", "strong0", "strong1", "supply0", "supply1",
    "table", "task", "time", "tran", "tranif0", "tranif1", "tri", "tri0",
    "tri1", "triand", "trior", "trireg", "unsigned", "use", "vectored", "wait",
    "wand", "weak0", "weak1", "while", "wire", "wor", "xnor", "xor",
}
34
35
def _printsig(ns, s):
    """Return the declaration fragment for signal ``s``.

    The result is, in order: ``"signed "`` when ``s.signed`` is true, a
    ``"[N-1:0] "`` range when ``len(s)`` is greater than one, and finally
    the name that ``ns.get_name(s)`` yields for the signal.
    """
    pieces = []
    if s.signed:
        pieces.append("signed ")
    width = len(s)
    if width > 1:
        # One-bit signals are declared without a range.
        pieces.append("[" + str(width - 1) + ":0] ")
    pieces.append(ns.get_name(s))
    return "".join(pieces)
45
46
def _printconstant(node):
    """Return ``(text, signed)`` for a constant node.

    ``node`` must expose ``nbits``, ``value`` and ``signed``.

    Unsigned constants are printed as ``<nbits>'d<value>``. Signed constants
    are printed as ``<nbits>'sd<magnitude>``: negative values are written as
    their two's-complement bit pattern (``2**nbits + value``), while
    non-negative values are written as-is. Previously ``2**nbits`` was added
    unconditionally, which produced literals wider than the declared size
    for non-negative values (e.g. ``4'sd19`` for the value 3).
    """
    if not node.signed:
        return str(node.nbits) + "'d" + str(node.value), False

    magnitude = node.value
    if magnitude < 0:
        # Only negative values need wrapping into the unsigned bit pattern.
        magnitude += 2**node.nbits
    return str(node.nbits) + "'sd" + str(magnitude), True
53
54
def _printexpr(ns, node):
    """Render an expression node and report its signedness.

    Returns a ``(text, signed)`` tuple. Handled node kinds are ``Constant``,
    ``Signal``, ``_Operator`` (one, two or three operands), ``_Slice``,
    ``Cat`` and ``Replicate``; ``Cat`` and ``Replicate`` results are always
    reported as unsigned.

    Raises ``TypeError`` for an operator with an unsupported operand count
    or for a node of any other type.
    """
    def as_signed(text):
        # Prepend a zero bit and reinterpret the result as signed, so an
        # unsigned operand can be mixed with a signed one.
        return "$signed({1'd0, " + text + "})"

    if isinstance(node, Constant):
        return _printconstant(node)

    if isinstance(node, Signal):
        return ns.get_name(node), node.signed

    if isinstance(node, _Operator):
        operands = node.operands
        count = len(operands)
        lhs, lhs_signed = _printexpr(ns, operands[0])
        if count == 1:
            if node.op == "-":
                # Negation always yields a signed result; an unsigned
                # operand is converted first.
                text = node.op + lhs if lhs_signed else "-" + as_signed(lhs)
                signed = True
            else:
                text = node.op + lhs
                signed = lhs_signed
        elif count == 2:
            rhs, rhs_signed = _printexpr(ns, operands[1])
            if node.op not in ["<<<", ">>>"]:
                # Except for the two shift operators listed above, bring
                # both sides to the same signedness.
                if rhs_signed and not lhs_signed:
                    lhs = as_signed(lhs)
                if lhs_signed and not rhs_signed:
                    rhs = as_signed(rhs)
            text = lhs + " " + node.op + " " + rhs
            signed = lhs_signed or rhs_signed
        elif count == 3:
            # The only three-operand operator is "m", printed as ?: .
            assert node.op == "m"
            on_true, true_signed = _printexpr(ns, operands[1])
            on_false, false_signed = _printexpr(ns, operands[2])
            if true_signed and not false_signed:
                on_false = as_signed(on_false)
            if false_signed and not true_signed:
                on_true = as_signed(on_true)
            text = lhs + " ? " + on_true + " : " + on_false
            signed = true_signed or false_signed
        else:
            raise TypeError
        return "(" + text + ")", signed

    if isinstance(node, _Slice):
        # Verilog does not like us slicing non-array signals...
        whole_one_bit_signal = (isinstance(node.value, Signal)
                                and len(node.value) == 1
                                and node.start == 0 and node.stop == 1)
        if whole_one_bit_signal:
            return _printexpr(ns, node.value)

        if node.start + 1 == node.stop:
            selector = "[" + str(node.start) + "]"
        else:
            selector = "[" + str(node.stop-1) + ":" + str(node.start) + "]"
        text, signed = _printexpr(ns, node.value)
        return text + selector, signed

    if isinstance(node, Cat):
        # Elements are emitted in reverse order of node.l.
        members = ", ".join(_printexpr(ns, v)[0] for v in reversed(node.l))
        return "{" + members + "}", False

    if isinstance(node, Replicate):
        return "{" + str(node.n) + "{" + _printexpr(ns, node.v)[0] + "}}", False

    raise TypeError("Expression of unrecognized type: '{}'".format(type(node).__name__))
115
116
# Assignment modes understood by _printnode(): always " = ", always " <= ",
# or decided per assignment via is_variable() on the left-hand side.
_AT_BLOCKING, _AT_NONBLOCKING, _AT_SIGNAL = range(3)
118
119
def _printnode(ns, at, level, node, target_filter=None):
    """Render a statement (or nested collection of statements) as Verilog.

    Parameters:
        ns: namespace passed through to ``_printexpr``.
        at: one of ``_AT_BLOCKING``, ``_AT_NONBLOCKING`` or ``_AT_SIGNAL``.
            The first two force `` = `` / `` <= ``; for any other value the
            operator is chosen per assignment: `` = `` when
            ``is_variable(node.l)`` is true, `` <= `` otherwise.
        level: number of tab characters to indent the emitted lines with.
        node: ``None``, an ``_Assign``, an iterable of nodes, an ``If`` or a
            ``Case``.
        target_filter: when not ``None``, any node whose ``list_targets()``
            result does not contain it is rendered as the empty string.

    Returns the generated text (possibly empty).

    Raises ``TypeError`` for a node of any other type.
    """
    # ``collections.Iterable`` was a deprecated alias that Python 3.10
    # removed; the ABC has to be taken from ``collections.abc``.
    from collections.abc import Iterable

    if node is None:
        return ""
    elif target_filter is not None and target_filter not in list_targets(node):
        return ""
    elif isinstance(node, _Assign):
        if at == _AT_BLOCKING:
            assignment = " = "
        elif at == _AT_NONBLOCKING:
            assignment = " <= "
        elif is_variable(node.l):
            assignment = " = "
        else:
            assignment = " <= "
        return "\t"*level + _printexpr(ns, node.l)[0] + assignment + _printexpr(ns, node.r)[0] + ";\n"
    elif isinstance(node, Iterable):
        # A statement list: render each member at the same level.
        return "".join(_printnode(ns, at, level, n, target_filter) for n in node)
    elif isinstance(node, If):
        r = "\t"*level + "if (" + _printexpr(ns, node.cond)[0] + ") begin\n"
        r += _printnode(ns, at, level + 1, node.t, target_filter)
        if node.f:
            r += "\t"*level + "end else begin\n"
            r += _printnode(ns, at, level + 1, node.f, target_filter)
        r += "\t"*level + "end\n"
        return r
    elif isinstance(node, Case):
        if not node.cases:
            # A Case without any branch produces no output at all.
            return ""
        r = "\t"*level + "case (" + _printexpr(ns, node.test)[0] + ")\n"
        # Constant-keyed branches are emitted in ascending value order; the
        # "default" branch, if present, always comes last.
        css = [(k, v) for k, v in node.cases.items() if isinstance(k, Constant)]
        css = sorted(css, key=lambda x: x[0].value)
        for choice, statements in css:
            r += "\t"*(level + 1) + _printexpr(ns, choice)[0] + ": begin\n"
            r += _printnode(ns, at, level + 2, statements, target_filter)
            r += "\t"*(level + 1) + "end\n"
        if "default" in node.cases:
            r += "\t"*(level + 1) + "default: begin\n"
            r += _printnode(ns, at, level + 2, node.cases["default"], target_filter)
            r += "\t"*(level + 1) + "end\n"
        r += "\t"*level + "endcase\n"
        return r
    else:
        raise TypeError("Node of unrecognized type: "+str(type(node)))
164
165
def _list_comb_wires(f):
    """Return the set of combinatorial targets that can be plain wires.

    ``group_by_targets(f.comb)`` yields groups whose first item is a set of
    targets and whose second item is the list of statements driving them.
    The targets of a group are included when that list consists of exactly
    one ``_Assign``.
    """
    wires = set()
    for group in group_by_targets(f.comb):
        targets, statements = group[0], group[1]
        sole_assign = len(statements) == 1 and isinstance(statements[0], _Assign)
        if sole_assign:
            wires |= targets
    return wires
173
174
def _printheader(f, ios, name, ns,
                 reg_initialization):
    """Build the module header and the internal signal declarations.

    Parameters:
        f: fragment being converted.
        ios: set of signals that form the module's port list.
        name: module name.
        ns: namespace used to obtain signal names.
        reg_initialization: when true, ``reg`` declarations carry an
            initial value printed from the signal's ``reset`` attribute.

    Returns the text from ``module <name>(`` through the declarations,
    followed by one blank line.
    """
    all_signals = list_signals(f) | list_special_ios(f, True, True, True)
    special_outs = list_special_ios(f, False, True, True)
    inouts = list_special_ios(f, False, False, True)
    driven = list_targets(f) | special_outs
    wires = _list_comb_wires(f) | special_outs

    def by_duid(sig):
        return sig.duid

    def direction(sig):
        # inout wins over output, output wins over input.
        if sig in inouts:
            return "inout"
        if sig not in driven:
            return "input"
        return "output" if sig in wires else "output reg"

    # Ports are listed in duid order and separated by ",\n".
    ports = ["\t" + direction(sig) + " " + _printsig(ns, sig)
             for sig in sorted(ios, key=by_duid)]

    # Every signal that is not a port gets a wire/reg declaration.
    declarations = []
    for sig in sorted(all_signals - ios, key=by_duid):
        decl = _printsig(ns, sig)
        if sig in wires:
            declarations.append("wire " + decl + ";\n")
        elif reg_initialization:
            declarations.append("reg " + decl + " = " + _printexpr(ns, sig.reset)[0] + ";\n")
        else:
            declarations.append("reg " + decl + ";\n")

    return ("module " + name + "(\n"
            + ",\n".join(ports)
            + "\n);\n\n"
            + "".join(declarations)
            + "\n")
208
209
def _printcomb(f, ns,
               display_run,
               dummy_signal,
               blocking_assign):
    """Emit the combinatorial logic of fragment ``f``.

    Statements are bucketed per target signal. A target driven by exactly
    one ``_Assign`` becomes a continuous ``assign``; any other target gets
    its own ``always @(*)`` block that first loads the target's reset value
    and then replays the statements filtered down to that target.

    Parameters:
        f: fragment whose ``comb`` statements are printed.
        ns: namespace used to obtain signal names.
        display_run: when true, each ``always`` block starts with a
            ``$display`` line carrying the block's index.
        dummy_signal: when true, emit ``dummy_s``/``dummy_d`` registers
            wrapped in ``synthesis translate_off``/``translate_on``
            comments (a dummy event so that a simulator runs the
            combinatorial process once at the beginning).
        blocking_assign: selects blocking (`` = ``) instead of
            non-blocking (`` <= ``) assignments inside ``always`` blocks.

    Returns the generated text, terminated by one blank line.
    """
    # NOTE(review): _list_comb_wires() decides wire-vs-reg per target
    # *group*, while this function decides assign-vs-always per individual
    # target; the two could disagree for statements that drive several
    # targets at once - confirm.
    r = ""
    if f.comb:
        if dummy_signal:
            # Generate a dummy event to get the simulator
            # to run the combinatorial process once at the beginning.
            syn_off = "// synthesis translate_off\n"
            syn_on = "// synthesis translate_on\n"
            dummy_s = Signal(name_override="dummy_s")
            r += syn_off
            r += "reg " + _printsig(ns, dummy_s) + ";\n"
            r += "initial " + ns.get_name(dummy_s) + " <= 1'd0;\n"
            r += syn_on

        # Map every target to the statements that drive it.
        target_stmt_map = collections.defaultdict(list)
        for statement in flat_iteration(f.comb):
            for t in list_targets(statement):
                target_stmt_map[t].append(statement)

        for n, (t, stmts) in enumerate(target_stmt_map.items()):
            assert isinstance(t, Signal)
            if len(stmts) == 1 and isinstance(stmts[0], _Assign):
                r += "assign " + _printnode(ns, _AT_BLOCKING, 0, stmts[0])
            else:
                if dummy_signal:
                    dummy_d = Signal(name_override="dummy_d")
                    r += "\n" + syn_off
                    r += "reg " + _printsig(ns, dummy_d) + ";\n"
                    r += syn_on

                r += "always @(*) begin\n"
                if display_run:
                    r += "\t$display(\"Running comb block #" + str(n) + "\");\n"
                if blocking_assign:
                    r += "\t" + ns.get_name(t) + " = " + _printexpr(ns, t.reset)[0] + ";\n"
                    r += _printnode(ns, _AT_BLOCKING, 1, stmts, t)
                else:
                    r += "\t" + ns.get_name(t) + " <= " + _printexpr(ns, t.reset)[0] + ";\n"
                    r += _printnode(ns, _AT_NONBLOCKING, 1, stmts, t)
                if dummy_signal:
                    r += syn_off
                    r += "\t" + ns.get_name(dummy_d) + " = " + ns.get_name(dummy_s) + ";\n"
                    r += syn_on
                r += "end\n"
    r += "\n"
    return r
269
270
def _printsync(f, ns):
    """Emit one ``always @(posedge <clk>)`` block per entry of ``f.sync``.

    Entries are visited in sorted key order; the clock name is taken from
    ``f.clock_domains[key].clk``. Each block is followed by a blank line.
    """
    blocks = []
    for domain, statements in sorted(f.sync.items(), key=itemgetter(0)):
        clk_name = ns.get_name(f.clock_domains[domain].clk)
        blocks.append("always @(posedge " + clk_name + ") begin\n"
                      + _printnode(ns, _AT_SIGNAL, 1, statements)
                      + "end\n\n")
    return "".join(blocks)
278
279
def _call_special_classmethod(overrides, obj, method, *args, **kwargs):
    """Call ``method`` for ``obj``, honouring class overrides.

    The attribute is looked up on ``overrides[obj.__class__]`` when the
    class of ``obj`` has an entry in ``overrides``, otherwise on the class
    itself. It is then invoked with ``obj`` as first argument followed by
    ``*args`` and ``**kwargs``.

    Returns the call's result, or ``None`` when the selected class has no
    such attribute.
    """
    cls = obj.__class__
    cls = overrides[cls] if cls in overrides else cls
    try:
        handler = getattr(cls, method)
    except AttributeError:
        return None
    # Called outside the try block so that errors raised by the handler
    # itself are never swallowed.
    return handler(obj, *args, **kwargs)
288
289
def _lower_specials_step(overrides, specials):
    """Run a single lowering pass over ``specials``.

    Specials are visited in ``duid`` order. For each one whose (possibly
    overridden) class provides ``lower`` and returns something other than
    ``None``, the fragment of the returned object is accumulated.

    Returns ``(fragment, lowered)`` where ``lowered`` is the set of
    specials that were lowered in this pass.
    """
    fragment = _Fragment()
    lowered = set()
    for special in sorted(specials, key=lambda x: x.duid):
        replacement = _call_special_classmethod(overrides, special, "lower")
        if replacement is None:
            continue
        fragment += replacement.get_fragment()
        lowered.add(special)
    return fragment, lowered
299
300
def _can_lower(overrides, specials):
    """Tell whether any special in ``specials`` offers a ``lower`` attribute.

    The attribute is searched on the override class when the special's
    class has an entry in ``overrides``, otherwise on the class itself.
    """
    def effective_class(special):
        cls = special.__class__
        return overrides[cls] if cls in overrides else cls

    return any(hasattr(effective_class(special), "lower")
               for special in specials)
309
310
def _lower_specials(overrides, specials):
    """Lower ``specials`` repeatedly until nothing lowerable remains.

    After the first pass, further passes run for as long as the
    accumulated fragment still contains specials that ``_can_lower``
    accepts; each pass merges its result into the fragment and removes the
    specials it lowered from ``fragment.specials``.

    Returns ``(fragment, lowered)`` with ``lowered`` the set of every
    special lowered across all passes.
    """
    fragment, lowered = _lower_specials_step(overrides, specials)
    while _can_lower(overrides, fragment.specials):
        step_fragment, step_lowered = _lower_specials_step(overrides, fragment.specials)
        fragment += step_fragment
        lowered |= step_lowered
        # Specials consumed by this pass must not be visited again.
        fragment.specials -= step_lowered
    return fragment, lowered
319
320
def _printspecials(overrides, specials, ns, add_data_file):
    """Concatenate the ``emit_verilog`` output of every special.

    Specials are processed in ``duid`` order; each ``emit_verilog`` is
    called with ``ns`` and ``add_data_file`` in addition to the special.

    Raises ``NotImplementedError`` when the call yields ``None`` for a
    special.
    """
    text = ""
    ordered = sorted(specials, key=lambda x: x.duid)
    for special in ordered:
        emitted = _call_special_classmethod(overrides, special, "emit_verilog", ns, add_data_file)
        if emitted is None:
            raise NotImplementedError("Special " + str(special) + " failed to implement emit_verilog")
        text += emitted
    return text
329
330
def convert(f, ios=None, name="top",
            special_overrides=dict(),
            create_clock_domains=True,
            display_run=False, asic_syntax=False):
    """Convert a fragment (or an object providing one) to Verilog.

    Parameters:
        f: a ``_Fragment``, or any object with a ``get_fragment()`` method.
        ios: set of signals forming the module ports; defaults to an empty
            set. Clock and reset signals of automatically created clock
            domains are added to it in place.
        name: name of the generated module.
        special_overrides: mapping from a special's class to the class
            whose ``lower`` / ``emit_verilog`` should be used instead.
        create_clock_domains: when true, a clock domain that is used but
            not found in ``f.clock_domains`` is created; when false this
            situation raises ``KeyError``.
        display_run: forwarded to ``_printcomb``.
        asic_syntax: when true, register initialization and the dummy
            simulation signal are disabled and combinatorial ``always``
            blocks use blocking assignments.

    Returns a ``ConvOutput`` whose ``ns`` attribute holds the namespace and
    whose main source is the generated module text.
    """
    output = ConvOutput()
    if not isinstance(f, _Fragment):
        f = f.get_fragment()
    if ios is None:
        ios = set()

    # Resolve every clock domain the fragment refers to.
    for cd_name in sorted(list_clock_domains(f)):
        try:
            f.clock_domains[cd_name]
        except KeyError:
            if not create_clock_domains:
                raise KeyError("Unresolved clock domain: '"+cd_name+"'")
            cd = ClockDomain(cd_name)
            f.clock_domains.append(cd)
            ios |= {cd.clk, cd.rst}

    # Lowering passes; their order is significant.
    f = lower_complex_slices(f)
    insert_resets(f)
    f = lower_basics(f)
    lowered_fragment, lowered_specials = _lower_specials(special_overrides, f.specials)
    f += lower_basics(lowered_fragment)

    named_signals = list_signals(f) \
        | list_special_ios(f, True, True, True) \
        | ios
    ns = build_namespace(named_signals, _reserved_keywords)
    ns.clock_domains = f.clock_domains
    output.ns = ns

    sections = [
        "/* Machine-generated using LiteX gen*/\n",
        _printheader(f, ios, name, ns,
                     reg_initialization=not asic_syntax),
        _printcomb(f, ns,
                   display_run=display_run,
                   dummy_signal=not asic_syntax,
                   blocking_assign=asic_syntax),
        _printsync(f, ns),
        # Specials that were lowered above have already been turned into
        # plain logic and must not be emitted again.
        _printspecials(special_overrides, f.specials - lowered_specials, ns, output.add_data_file),
        "endmodule\n",
    ]
    output.set_main_source("".join(sections))

    return output