47bf570eea017b6cc0858aefea1aedb95bfa9a49
2 from contextlib
import contextmanager
4 from nmigen
.hdl
.ast
import SignalSet
5 from nmigen
.hdl
.xfrm
import ValueVisitor
, StatementVisitor
, LHSGroupFilter
6 from nmigen
.sim
._base
import BaseProcess
8 __all__
= ["PyRTLProcess"]
11 class PyRTLProcess(BaseProcess
):
12 __slots__
= ("is_comb", "runnable", "passive", "name", "filename", "crtl", "run")
14 def __init__(self
, *, is_comb
):
15 self
.is_comb
= is_comb
19 self
.runnable
= self
.is_comb
def append(self, code):
    """Append one line of generated code to the emitter buffer.

    Three pieces are pushed onto ``self._buffer`` in order: the indent
    prefix for the current nesting level, ``code`` itself, and a newline.
    """
    # NOTE(review): the indent unit appears as a single space here; confirm
    # it was not collapsed from a wider unit when this text was extracted.
    self._buffer.append(" " * self._level)
    self._buffer.append(code)
    self._buffer.append("\n")
49 def flush(self
, indent
=""):
50 code
= "".join(self
._buffer
)
54 def gen_var(self
, prefix
):
55 name
= f
"{prefix}_{self._suffix}"
59 def def_var(self
, prefix
, value
):
60 name
= self
.gen_var(prefix
)
61 self
.append(f
"uint64_t {name} = {value};")
def assign(self, lhs, rhs):
    """Emit the assignment ``lhs = rhs`` as one line of generated code.

    Both arguments are interpolated verbatim; no terminator is added.
    """
    self.append(f"{lhs} = {rhs}")
68 self
.append(f
"if ({cond})")
def else_if(self, cond):
    """Emit an ``else if (cond)`` header line for the generated code.

    ``cond`` is interpolated verbatim between the parentheses.
    """
    self.append(f"else if ({cond})")
78 def __init__(self
, state
, emitter
):
80 self
.emitter
= emitter
83 class _ValueCompiler(ValueVisitor
, _Compiler
):
85 "sign": lambda value
, sign
: value | sign
if value
& sign
else value
,
86 "zdiv": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
// rhs
,
87 "zmod": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
% rhs
,
def on_ClockSignal(self, value):
    """Reject ``ClockSignal`` values; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_ResetSignal(self, value):
    """Reject ``ResetSignal`` values; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_AnyConst(self, value):
    """Reject ``AnyConst`` values; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_AnySeq(self, value):
    """Reject ``AnySeq`` values; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Sample(self, value):
    """Reject ``Sample`` values; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Initial(self, value):
    """Reject ``Initial`` values; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
109 class _RHSValueCompiler(_ValueCompiler
):
110 def __init__(self
, state
, emitter
, *, mode
, inputs
=None):
111 super().__init
__(state
, emitter
)
112 assert mode
in ("curr", "next")
114 # If not None, `inputs` gets populated with RHS signals.
def on_Const(self, value):
    """Compile a constant to source text.

    Returns:
        The string form of ``value.value``.
    """
    return f"{value.value}"
def on_Signal(self, value):
    """Compile a read of a signal to source text.

    When ``self.inputs`` is not None the signal is recorded in it, so the
    caller can collect the signals read on the right-hand side.

    Returns:
        ``slots[<index>].curr`` when ``self.mode`` is ``"curr"``, otherwise
        ``next_<index>``, where ``<index>`` is what
        ``self.state.get_signal(value)`` returns.
    """
    if self.inputs is not None:
        self.inputs.add(value)

    if self.mode == "curr":
        return f"slots[{self.state.get_signal(value)}].{self.mode}"
    # Any other mode reads the local holding the pending value.
    return f"next_{self.state.get_signal(value)}"
129 def on_Operator(self
, value
):
131 value_mask
= (1 << len(value
)) - 1
132 return f
"({value_mask} & {self(value)})"
135 if value
.shape().signed
:
136 return f
"sign({mask(value)}, {-1 << (len(value) - 1)})"
140 if len(value
.operands
) == 1:
141 arg
, = value
.operands
142 if value
.operator
== "~":
143 return f
"(~{self(arg)})"
144 if value
.operator
== "-":
145 return f
"(-{sign(arg)})"
146 if value
.operator
== "b":
147 return f
"bool({mask(arg)})"
148 if value
.operator
== "r|":
149 return f
"(0 != {mask(arg)})"
150 if value
.operator
== "r&":
151 return f
"({(1 << len(arg)) - 1} == {mask(arg)})"
152 if value
.operator
== "r^":
153 # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
154 return f
"(format({mask(arg)}, 'b').count('1') % 2)"
155 if value
.operator
in ("u", "s"):
156 # These operators don't change the bit pattern, only its interpretation.
158 elif len(value
.operands
) == 2:
159 lhs
, rhs
= value
.operands
160 if value
.operator
== "+":
161 return f
"({sign(lhs)} + {sign(rhs)})"
162 if value
.operator
== "-":
163 return f
"({sign(lhs)} - {sign(rhs)})"
164 if value
.operator
== "*":
165 return f
"({sign(lhs)} * {sign(rhs)})"
166 if value
.operator
== "//":
167 return f
"zdiv({sign(lhs)}, {sign(rhs)})"
168 if value
.operator
== "%":
169 return f
"zmod({sign(lhs)}, {sign(rhs)})"
170 if value
.operator
== "&":
171 return f
"({self(lhs)} & {self(rhs)})"
172 if value
.operator
== "|":
173 return f
"({self(lhs)} | {self(rhs)})"
174 if value
.operator
== "^":
175 return f
"({self(lhs)} ^ {self(rhs)})"
176 if value
.operator
== "<<":
177 return f
"({sign(lhs)} << {sign(rhs)})"
178 if value
.operator
== ">>":
179 return f
"({sign(lhs)} >> {sign(rhs)})"
180 if value
.operator
== "==":
181 return f
"({sign(lhs)} == {sign(rhs)})"
182 if value
.operator
== "!=":
183 return f
"({sign(lhs)} != {sign(rhs)})"
184 if value
.operator
== "<":
185 return f
"({sign(lhs)} < {sign(rhs)})"
186 if value
.operator
== "<=":
187 return f
"({sign(lhs)} <= {sign(rhs)})"
188 if value
.operator
== ">":
189 return f
"({sign(lhs)} > {sign(rhs)})"
190 if value
.operator
== ">=":
191 return f
"({sign(lhs)} >= {sign(rhs)})"
192 elif len(value
.operands
) == 3:
193 if value
.operator
== "m":
194 sel
, val1
, val0
= value
.operands
195 return f
"(({mask(sel)}) ? ({self(val1)}) : ({self(val0)}))"
196 raise NotImplementedError("Operator '{}' not implemented".format(value
.operator
)) # :nocov:
def on_Slice(self, value):
    """Compile a read of a fixed slice to source text.

    The compiled inner value is shifted right by ``value.start`` and masked
    to ``len(value)`` bits.

    Returns:
        An expression of the form ``(<mask> & (<inner> >> <start>))``.
    """
    return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"
def on_Part(self, value):
    """Compile a read of a variable-offset part select to source text.

    The compiled offset is masked to ``len(value.offset)`` bits and scaled
    by ``value.stride``. The compiled inner value is shifted right by that
    amount and masked to ``value.width`` bits.

    Returns:
        An expression of the form ``(<width_mask> & <inner> >> <offset>)``.
    """
    offset_mask = (1 << len(value.offset)) - 1
    offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
    # No inner parentheses around the shift: this relies on `>>` binding
    # tighter than `&` in the language the expression is emitted for.
    return f"({(1 << value.width) - 1} & " \
           f"{self(value.value)} >> {offset})"
207 def on_Cat(self
, value
):
210 for part
in value
.parts
:
211 part_mask
= (1 << len(part
)) - 1
212 gen_parts
.append(f
"(({part_mask} & {self(part)}) << {offset})")
215 return f
"({' | '.join(gen_parts)})"
218 def on_Repl(self
, value
):
219 part_mask
= (1 << len(value
.value
)) - 1
220 gen_part
= self
.emitter
.def_var("repl", f
"{part_mask} & {self(value.value)}")
223 for _
in range(value
.count
):
224 gen_parts
.append(f
"({gen_part} << {offset})")
225 offset
+= len(value
.value
)
227 return f
"({' | '.join(gen_parts)})"
230 def on_ArrayProxy(self
, value
):
231 index_mask
= (1 << len(value
.index
)) - 1
232 gen_index
= self
.emitter
.def_var("rhs_index", f
"{index_mask} & {self(value.index)}")
233 gen_value
= self
.emitter
.gen_var("rhs_proxy")
235 for index
, elem
in enumerate(value
.elems
):
237 self
.emitter
.if_(f
"{index} == {gen_index}")
239 self
.emitter
.else_if(f
"{index} == {gen_index}")
240 with self
.emitter
.nest():
241 self
.emitter
.assign(f
"{gen_value}", f
"{self(elem)}")
243 with self
.emitter
.nest():
244 self
.emitter
.assign(f
"{gen_value}", f
"{self(value.elems[-1])}")
250 def compile(cls
, state
, value
, *, mode
):
251 emitter
= _PythonEmitter()
252 compiler
= cls(state
, emitter
, mode
=mode
)
253 emitter
.assign(f
"result", f
"{compiler(value)}")
254 return emitter
.flush()
257 class _LHSValueCompiler(_ValueCompiler
):
258 def __init__(self
, state
, emitter
, *, rhs
, outputs
=None):
259 super().__init
__(state
, emitter
)
260 # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
261 # the offset of a Part.
263 # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
264 # update of an lvalue.
265 self
.lrhs
= _RHSValueCompiler(state
, emitter
, mode
="next", inputs
=None)
266 # If not None, `outputs` gets populated with signals on LHS.
267 self
.outputs
= outputs
def on_Const(self, value):
    """Reject constants; the LHS compiler has no translation for them.

    Raises:
        TypeError: always.
    """
    raise TypeError # :nocov:
272 def on_Signal(self
, value
):
273 if self
.outputs
is not None:
274 self
.outputs
.add(value
)
277 value_mask
= (1 << len(value
)) - 1
279 # TODO: useful trick, actually put the name into the c code
280 # but this has to be done consistently right across the board.
281 # all occurrences of next_{....} have to use the same trick
282 # but at least then the names in the auto-generated c-code
284 #if hasattr(value, "name") and value.name is not None:
286 if value
.shape().signed
:
287 value_sign
= f
"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
289 value_sign
= f
"{value_mask} & {arg}"
290 self
.emitter
.append(f
"next_{name}{self.state.get_signal(value)} = {value_sign};")
def on_Operator(self, value):
    """Reject operator expressions; the LHS compiler has no translation for them.

    Raises:
        TypeError: always.
    """
    raise TypeError # :nocov:
296 def on_Slice(self
, value
):
298 width_mask
= (1 << (value
.stop
- value
.start
)) - 1
299 self(value
.value
)(f
"({self.lrhs(value.value)} & " \
300 f
"{~(width_mask << value.start)} | " \
301 f
"(({width_mask} & {arg}) << {value.start}))")
304 def on_Part(self
, value
):
306 width_mask
= (1 << value
.width
) - 1
307 offset_mask
= (1 << len(value
.offset
)) - 1
308 offset
= f
"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
309 self(value
.value
)(f
"({self.lrhs(value.value)} & " \
310 f
"~({width_mask} << {offset}) | " \
311 f
"(({width_mask} & {arg}) << {offset}))")
314 def on_Cat(self
, value
):
316 gen_arg
= self
.emitter
.def_var("cat", arg
)
318 for part
in value
.parts
:
319 part_mask
= (1 << len(part
)) - 1
320 self(part
)(f
"({part_mask} & ({gen_arg} >> {offset}))")
def on_Repl(self, value):
    """Reject replications; the LHS compiler has no translation for them.

    Raises:
        TypeError: always.
    """
    raise TypeError # :nocov:
327 def on_ArrayProxy(self
, value
):
329 index_mask
= (1 << len(value
.index
)) - 1
330 gen_index
= self
.emitter
.def_var("index", f
"{self.rrhs(value.index)} & {index_mask}")
332 for index
, elem
in enumerate(value
.elems
):
334 self
.emitter
.if_(f
"{index} == {gen_index}")
336 self
.emitter
.append(f
"{index} == {gen_index}")
337 with self
.emitter
.nest():
340 with self
.emitter
.nest():
341 self(value
.elems
[-1])(arg
)
345 class _StatementCompiler(StatementVisitor
, _Compiler
):
def __init__(self, state, emitter, *, inputs=None, outputs=None):
    """Set up the statement compiler and its two value compilers.

    Args:
        state: forwarded to the base initializer and to both value compilers.
        emitter: forwarded likewise; receives the generated code.
        inputs: passed to the RHS compiler as its ``inputs`` collection.
        outputs: passed to the LHS compiler as its ``outputs`` collection.
    """
    super().__init__(state, emitter)
    # Right-hand sides are compiled in "curr" mode.
    self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
    # The LHS compiler is handed the RHS compiler created just above.
    self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)
351 def on_statements(self
, stmts
):
355 self
.emitter
.append("pass")
def on_Assign(self, stmt):
    """Compile an assignment statement.

    The compiled right-hand side is masked to ``len(stmt.rhs)`` bits. If
    the right-hand side's shape is signed, the result is wrapped in a
    ``sign(...)`` call whose second argument is ``-1 << (len(stmt.rhs) - 1)``.
    The resulting expression is handed to the generator that
    ``self.lhs(stmt.lhs)`` returns.

    Returns:
        Whatever that LHS generator returns.
    """
    gen_rhs = f"({(1 << len(stmt.rhs)) - 1} & {self.rhs(stmt.rhs)})"
    if stmt.rhs.shape().signed:
        gen_rhs = f"sign({gen_rhs}, {-1 << (len(stmt.rhs) - 1)})"
    return self.lhs(stmt.lhs)(gen_rhs)
363 def on_Switch(self
, stmt
):
364 gen_test
= self
.emitter
.def_var("test",
365 f
"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
366 for index
, (patterns
, stmts
) in enumerate(stmt
.cases
.items()):
369 gen_checks
.append(f
"True")
371 for pattern
in patterns
:
373 mask
= int("".join("0" if b
== "-" else "1" for b
in pattern
), 2)
374 value
= int("".join("0" if b
== "-" else b
for b
in pattern
), 2)
375 gen_checks
.append(f
"{value} == ({mask} & {gen_test})")
377 value
= int(pattern
, 2)
378 gen_checks
.append(f
"{value} == {gen_test}")
380 self
.emitter
.if_(f
"{' or '.join(gen_checks)}")
382 self
.emitter
.else_if(f
"{' or '.join(gen_checks)}")
383 with self
.emitter
.nest():
def on_Display(self, stmt):
    """Reject ``Display`` statements; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Assert(self, stmt):
    """Reject ``Assert`` statements; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Assume(self, stmt):
    """Reject ``Assume`` statements; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Cover(self, stmt):
    """Reject ``Cover`` statements; this compiler does not implement them.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
399 def compile(cls
, state
, stmt
):
400 output_indexes
= [state
.get_signal(signal
) for signal
in stmt
._lhs
_signals
()]
401 emitter
= _PythonEmitter()
402 for signal_index
in output_indexes
:
403 emitter
.append(f
"uint64_t next_{signal_index} = slots[{signal_index}].next")
404 compiler
= cls(state
, emitter
)
406 for signal_index
in output_indexes
:
407 emitter
.append(f
"slots[{signal_index}].set(next_{signal_index})")
408 return emitter
.flush()
411 class _FragmentCompiler
:
412 def __init__(self
, state
):
415 def __call__(self
, fragment
):
418 for index
, (domain_name
, domain_signals
) in enumerate(fragment
.drivers
.items()):
419 domain_stmts
= LHSGroupFilter(domain_signals
)(fragment
.statements
)
420 domain_process
= PyRTLProcess(is_comb
=domain_name
is None)
422 domain_process
.name
= f
"{id(fragment)}_{domain_name or ''}_{index}"
424 emitter
= _PythonEmitter()
425 emitter
.append(f
"void run_{domain_process.name}(void)")
427 if domain_name
is None:
428 for signal
in domain_signals
:
429 signal_index
= self
.state
.get_signal(signal
)
430 emitter
.append(f
"uint64_t next_{signal_index} = {signal.reset};")
433 _StatementCompiler(self
.state
, emitter
, inputs
=inputs
)(domain_stmts
)
436 self
.state
.add_trigger(domain_process
, input)
439 domain
= fragment
.domains
[domain_name
]
440 clk_trigger
= 1 if domain
.clk_edge
== "pos" else 0
441 self
.state
.add_trigger(domain_process
, domain
.clk
, trigger
=clk_trigger
)
442 if domain
.rst
is not None and domain
.async_reset
:
444 self
.state
.add_trigger(domain_process
, domain
.rst
, trigger
=rst_trigger
)
446 for signal
in domain_signals
:
447 signal_index
= self
.state
.get_signal(signal
)
448 emitter
.append(f
"uint64_t next_{signal_index} = slots[{signal_index}].next;")
450 _StatementCompiler(self
.state
, emitter
)(domain_stmts
)
452 for signal
in domain_signals
:
453 signal_index
= self
.state
.get_signal(signal
)
454 emitter
.append(f
"set({signal_index}, next_{signal_index});")
456 code
= "#include <stdint.h>\n"
457 code
+= "#include \"common.h\"\n"
458 code
+= emitter
.flush()
460 file = open(f
"crtl/{domain_process.name}.c", "w")
464 processes
.add(domain_process
)
466 for subfragment_index
, (subfragment
, subfragment_name
) in enumerate(fragment
.subfragments
):
467 if subfragment_name
is None:
468 subfragment_name
= "U${}".format(subfragment_index
)
469 processes
.update(self(subfragment
))