3 from contextlib
import contextmanager
6 from ..hdl
.ast
import SignalSet
7 from ..hdl
.xfrm
import ValueVisitor
, StatementVisitor
, LHSGroupFilter
8 from ._base
import BaseProcess
11 __all__
= ["PyRTLProcess"]
14 class PyRTLProcess(BaseProcess
):
15 __slots__
= ("is_comb", "runnable", "passive", "run")
17 def __init__(self
, *, is_comb
):
18 self
.is_comb
= is_comb
23 self
.runnable
= self
.is_comb
def append(self, code):
    """Append one line of generated source to the emitter buffer.

    The line is prefixed with indentation (the indent string repeated
    ``self._level`` times) and terminated with a newline.
    """
    line_parts = (" " * self._level, code, "\n")
    for part in line_parts:
        self._buffer.append(part)
44 def flush(self
, indent
=""):
45 code
= "".join(self
._buffer
)
49 def gen_var(self
, prefix
):
50 name
= f
"{prefix}_{self._suffix}"
54 def def_var(self
, prefix
, value
):
55 name
= self
.gen_var(prefix
)
56 self
.append(f
"{name} = {value}")
61 def __init__(self
, state
, emitter
):
63 self
.emitter
= emitter
66 class _ValueCompiler(ValueVisitor
, _Compiler
):
68 "sign": lambda value
, sign
: value | sign
if value
& sign
else value
,
69 "zdiv": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
// rhs
,
70 "zmod": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
% rhs
,
def on_ClockSignal(self, value):
    """ClockSignal is not handled by this compiler; always raises NotImplementedError."""
    raise NotImplementedError()  # :nocov:
def on_ResetSignal(self, value):
    """ResetSignal is not handled by this compiler; always raises NotImplementedError."""
    raise NotImplementedError()  # :nocov:
def on_AnyConst(self, value):
    """AnyConst is not handled by this compiler; always raises NotImplementedError."""
    raise NotImplementedError()  # :nocov:
def on_AnySeq(self, value):
    """AnySeq is not handled by this compiler; always raises NotImplementedError."""
    raise NotImplementedError()  # :nocov:
def on_Sample(self, value):
    """Sample is not handled by this compiler; always raises NotImplementedError."""
    raise NotImplementedError()  # :nocov:
def on_Initial(self, value):
    """Initial is not handled by this compiler; always raises NotImplementedError."""
    raise NotImplementedError()  # :nocov:
92 class _RHSValueCompiler(_ValueCompiler
):
93 def __init__(self
, state
, emitter
, *, mode
, inputs
=None):
94 super().__init
__(state
, emitter
)
95 assert mode
in ("curr", "next")
97 # If not None, `inputs` gets populated with RHS signals.
def on_Const(self, value):
    """Render a constant as the decimal literal of its ``.value``."""
    literal = value.value
    return f"{literal}"
def on_Signal(self, value):
    """Generate a read of ``value`` from simulation state.

    If ``self.inputs`` is a collection (not None), ``value`` is added to it.
    In "curr" mode the expression reads ``slots[<index>].curr``; in any
    other mode it names the local variable ``next_<index>``.
    """
    collected = self.inputs
    if collected is not None:
        collected.add(value)

    if self.mode != "curr":
        return f"next_{self.state.get_signal(value)}"
    return f"slots[{self.state.get_signal(value)}].{self.mode}"
112 def on_Operator(self
, value
):
114 value_mask
= (1 << len(value
)) - 1
115 return f
"({value_mask} & {self(value)})"
118 if value
.shape().signed
:
119 return f
"sign({mask(value)}, {-1 << (len(value) - 1)})"
123 if len(value
.operands
) == 1:
124 arg
, = value
.operands
125 if value
.operator
== "~":
126 return f
"(~{self(arg)})"
127 if value
.operator
== "-":
128 return f
"(-{sign(arg)})"
129 if value
.operator
== "b":
130 return f
"bool({mask(arg)})"
131 if value
.operator
== "r|":
132 return f
"(0 != {mask(arg)})"
133 if value
.operator
== "r&":
134 return f
"({(1 << len(arg)) - 1} == {mask(arg)})"
135 if value
.operator
== "r^":
136 # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
137 return f
"(format({mask(arg)}, 'b').count('1') % 2)"
138 if value
.operator
in ("u", "s"):
139 # These operators don't change the bit pattern, only its interpretation.
141 elif len(value
.operands
) == 2:
142 lhs
, rhs
= value
.operands
143 if value
.operator
== "+":
144 return f
"({sign(lhs)} + {sign(rhs)})"
145 if value
.operator
== "-":
146 return f
"({sign(lhs)} - {sign(rhs)})"
147 if value
.operator
== "*":
148 return f
"({sign(lhs)} * {sign(rhs)})"
149 if value
.operator
== "//":
150 return f
"zdiv({sign(lhs)}, {sign(rhs)})"
151 if value
.operator
== "%":
152 return f
"zmod({sign(lhs)}, {sign(rhs)})"
153 if value
.operator
== "&":
154 return f
"({self(lhs)} & {self(rhs)})"
155 if value
.operator
== "|":
156 return f
"({self(lhs)} | {self(rhs)})"
157 if value
.operator
== "^":
158 return f
"({self(lhs)} ^ {self(rhs)})"
159 if value
.operator
== "<<":
160 return f
"({sign(lhs)} << {sign(rhs)})"
161 if value
.operator
== ">>":
162 return f
"({sign(lhs)} >> {sign(rhs)})"
163 if value
.operator
== "==":
164 return f
"({sign(lhs)} == {sign(rhs)})"
165 if value
.operator
== "!=":
166 return f
"({sign(lhs)} != {sign(rhs)})"
167 if value
.operator
== "<":
168 return f
"({sign(lhs)} < {sign(rhs)})"
169 if value
.operator
== "<=":
170 return f
"({sign(lhs)} <= {sign(rhs)})"
171 if value
.operator
== ">":
172 return f
"({sign(lhs)} > {sign(rhs)})"
173 if value
.operator
== ">=":
174 return f
"({sign(lhs)} >= {sign(rhs)})"
175 elif len(value
.operands
) == 3:
176 if value
.operator
== "m":
177 sel
, val1
, val0
= value
.operands
178 return f
"({self(val1)} if {mask(sel)} else {self(val0)})"
179 raise NotImplementedError("Operator '{}' not implemented".format(value
.operator
)) # :nocov:
def on_Slice(self, value):
    """Generate an expression extracting ``len(value)`` bits starting at ``value.start``."""
    width_mask = (1 << len(value)) - 1
    source = self(value.value)
    return f"({width_mask} & ({source} >> {value.start}))"
def on_Part(self, value):
    """Generate an expression extracting ``value.width`` bits at a dynamic offset.

    The offset expression is masked to the width of ``value.offset`` and
    scaled by ``value.stride``. The offset operand is compiled before the
    base value, so any code the sub-compilations emit keeps its original order.
    """
    offset_bits = (1 << len(value.offset)) - 1
    shift_expr = f"({value.stride} * ({offset_bits} & {self(value.offset)}))"
    width_bits = (1 << value.width) - 1
    base_expr = self(value.value)
    # In Python `>>` binds tighter than `&`, so this is mask & (base >> shift).
    return f"({width_bits} & {base_expr} >> {shift_expr})"
190 def on_Cat(self
, value
):
193 for part
in value
.parts
:
194 part_mask
= (1 << len(part
)) - 1
195 gen_parts
.append(f
"(({part_mask} & {self(part)}) << {offset})")
198 return f
"({' | '.join(gen_parts)})"
201 def on_Repl(self
, value
):
202 part_mask
= (1 << len(value
.value
)) - 1
203 gen_part
= self
.emitter
.def_var("repl", f
"{part_mask} & {self(value.value)}")
206 for _
in range(value
.count
):
207 gen_parts
.append(f
"({gen_part} << {offset})")
208 offset
+= len(value
.value
)
210 return f
"({' | '.join(gen_parts)})"
213 def on_ArrayProxy(self
, value
):
214 index_mask
= (1 << len(value
.index
)) - 1
215 gen_index
= self
.emitter
.def_var("rhs_index", f
"{index_mask} & {self(value.index)}")
216 gen_value
= self
.emitter
.gen_var("rhs_proxy")
218 for index
, elem
in enumerate(value
.elems
):
220 self
.emitter
.append(f
"if {index} == {gen_index}:")
222 self
.emitter
.append(f
"elif {index} == {gen_index}:")
223 with self
.emitter
.indent():
224 self
.emitter
.append(f
"{gen_value} = {self(elem)}")
225 self
.emitter
.append(f
"else:")
226 with self
.emitter
.indent():
227 self
.emitter
.append(f
"{gen_value} = {self(value.elems[-1])}")
def compile(cls, state, value, *, mode):
    """Compile ``value`` into standalone Python source that assigns it to ``result``.

    A fresh ``_PythonEmitter`` collects any helper lines the compiler emits,
    followed by ``result = <expr>``; the flushed source is returned.
    NOTE(review): takes ``cls`` -- presumably decorated ``@classmethod`` on a
    line not visible in this view.
    """
    emitter = _PythonEmitter()
    value_compiler = cls(state, emitter, mode=mode)
    expression = value_compiler(value)
    emitter.append(f"result = {expression}")
    return emitter.flush()
240 class _LHSValueCompiler(_ValueCompiler
):
241 def __init__(self
, state
, emitter
, *, rhs
, outputs
=None):
242 super().__init
__(state
, emitter
)
243 # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
244 # the offset of a Part.
246 # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
247 # update of an lvalue.
248 self
.lrhs
= _RHSValueCompiler(state
, emitter
, mode
="next", inputs
=None)
249 # If not None, `outputs` gets populated with signals on LHS.
250 self
.outputs
= outputs
def on_Const(self, value):
    """Constants are rejected on the left-hand side; always raises TypeError."""
    raise TypeError()  # :nocov:
255 def on_Signal(self
, value
):
256 if self
.outputs
is not None:
257 self
.outputs
.add(value
)
260 value_mask
= (1 << len(value
)) - 1
261 if value
.shape().signed
:
262 value_sign
= f
"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
264 value_sign
= f
"{value_mask} & {arg}"
265 self
.emitter
.append(f
"next_{self.state.get_signal(value)} = {value_sign}")
def on_Operator(self, value):
    """Operators are rejected on the left-hand side; always raises TypeError."""
    raise TypeError()  # :nocov:
271 def on_Slice(self
, value
):
273 width_mask
= (1 << (value
.stop
- value
.start
)) - 1
274 self(value
.value
)(f
"({self.lrhs(value.value)} & " \
275 f
"{~(width_mask << value.start)} | " \
276 f
"(({width_mask} & {arg}) << {value.start}))")
279 def on_Part(self
, value
):
281 width_mask
= (1 << value
.width
) - 1
282 offset_mask
= (1 << len(value
.offset
)) - 1
283 offset
= f
"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
284 self(value
.value
)(f
"({self.lrhs(value.value)} & " \
285 f
"~({width_mask} << {offset}) | " \
286 f
"(({width_mask} & {arg}) << {offset}))")
289 def on_Cat(self
, value
):
291 gen_arg
= self
.emitter
.def_var("cat", arg
)
293 for part
in value
.parts
:
294 part_mask
= (1 << len(part
)) - 1
295 self(part
)(f
"({part_mask} & ({gen_arg} >> {offset}))")
def on_Repl(self, value):
    """Repl is rejected on the left-hand side; always raises TypeError."""
    raise TypeError()  # :nocov:
302 def on_ArrayProxy(self
, value
):
304 index_mask
= (1 << len(value
.index
)) - 1
305 gen_index
= self
.emitter
.def_var("index", f
"{self.rrhs(value.index)} & {index_mask}")
307 for index
, elem
in enumerate(value
.elems
):
309 self
.emitter
.append(f
"if {index} == {gen_index}:")
311 self
.emitter
.append(f
"elif {index} == {gen_index}:")
312 with self
.emitter
.indent():
314 self
.emitter
.append(f
"else:")
315 with self
.emitter
.indent():
316 self(value
.elems
[-1])(arg
)
318 self
.emitter
.append(f
"pass")
322 class _StatementCompiler(StatementVisitor
, _Compiler
):
def __init__(self, state, emitter, *, inputs=None, outputs=None):
    """Set up the right- and left-hand-side value compilers for statements.

    ``self.rhs`` compiles reads in "curr" mode, collecting signals into
    ``inputs`` when it is given. ``self.lhs`` compiles assignment targets,
    collecting signals into ``outputs`` when it is given, and is handed the
    same RHS compiler instance.
    """
    super().__init__(state, emitter)
    read_compiler = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
    self.rhs = read_compiler
    self.lhs = _LHSValueCompiler(state, emitter, rhs=read_compiler, outputs=outputs)
328 def on_statements(self
, stmts
):
332 self
.emitter
.append("pass")
def on_Assign(self, stmt):
    """Compile an assignment statement.

    The RHS expression is masked to the RHS width; if the RHS shape is
    signed, it is additionally wrapped in the generated-code helper ``sign``
    with the sign bit of that width. The resulting expression is passed to
    the callable returned by ``self.lhs(stmt.lhs)``, whose result is returned.
    """
    rhs_width = len(stmt.rhs)
    value_expr = f"({(1 << rhs_width) - 1} & {self.rhs(stmt.rhs)})"
    if stmt.rhs.shape().signed:
        value_expr = f"sign({value_expr}, {-1 << (rhs_width - 1)})"
    store = self.lhs(stmt.lhs)
    return store(value_expr)
340 def on_Switch(self
, stmt
):
341 gen_test
= self
.emitter
.def_var("test",
342 f
"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
343 for index
, (patterns
, stmts
) in enumerate(stmt
.cases
.items()):
346 gen_checks
.append(f
"True")
348 for pattern
in patterns
:
350 mask
= int("".join("0" if b
== "-" else "1" for b
in pattern
), 2)
351 value
= int("".join("0" if b
== "-" else b
for b
in pattern
), 2)
352 gen_checks
.append(f
"{value} == ({mask} & {gen_test})")
354 value
= int(pattern
, 2)
355 gen_checks
.append(f
"{value} == {gen_test}")
357 self
.emitter
.append(f
"if {' or '.join(gen_checks)}:")
359 self
.emitter
.append(f
"elif {' or '.join(gen_checks)}:")
360 with self
.emitter
.indent():
def on_Assert(self, stmt):
    """Assert statements are not handled by this compiler; always raises NotImplementedError."""
    raise NotImplementedError()  # :nocov:
def on_Assume(self, stmt):
    """Assume statements are not handled by this compiler; always raises NotImplementedError."""
    raise NotImplementedError()  # :nocov:
def on_Cover(self, stmt):
    """Cover statements are not handled by this compiler; always raises NotImplementedError."""
    raise NotImplementedError()  # :nocov:
373 def compile(cls
, state
, stmt
):
374 output_indexes
= [state
.get_signal(signal
) for signal
in stmt
._lhs
_signals
()]
375 emitter
= _PythonEmitter()
376 for signal_index
in output_indexes
:
377 emitter
.append(f
"next_{signal_index} = slots[{signal_index}].next")
378 compiler
= cls(state
, emitter
)
380 for signal_index
in output_indexes
:
381 emitter
.append(f
"slots[{signal_index}].set(next_{signal_index})")
382 return emitter
.flush()
385 class _FragmentCompiler
:
386 def __init__(self
, state
):
389 def __call__(self
, fragment
):
392 for domain_name
, domain_signals
in fragment
.drivers
.items():
393 domain_stmts
= LHSGroupFilter(domain_signals
)(fragment
.statements
)
394 domain_process
= PyRTLProcess(is_comb
=domain_name
is None)
396 emitter
= _PythonEmitter()
397 emitter
.append(f
"def run():")
400 if domain_name
is None:
401 for signal
in domain_signals
:
402 signal_index
= self
.state
.get_signal(signal
)
403 emitter
.append(f
"next_{signal_index} = {signal.reset}")
406 _StatementCompiler(self
.state
, emitter
, inputs
=inputs
)(domain_stmts
)
409 self
.state
.add_trigger(domain_process
, input)
412 domain
= fragment
.domains
[domain_name
]
413 clk_trigger
= 1 if domain
.clk_edge
== "pos" else 0
414 self
.state
.add_trigger(domain_process
, domain
.clk
, trigger
=clk_trigger
)
415 if domain
.rst
is not None and domain
.async_reset
:
417 self
.state
.add_trigger(domain_process
, domain
.rst
, trigger
=rst_trigger
)
419 for signal
in domain_signals
:
420 signal_index
= self
.state
.get_signal(signal
)
421 emitter
.append(f
"next_{signal_index} = slots[{signal_index}].next")
423 _StatementCompiler(self
.state
, emitter
)(domain_stmts
)
425 for signal
in domain_signals
:
426 signal_index
= self
.state
.get_signal(signal
)
427 emitter
.append(f
"slots[{signal_index}].set(next_{signal_index})")
429 # There shouldn't be any exceptions raised by the generated code, but if there are
430 # (almost certainly due to a bug in the code generator), use this environment variable
431 # to make backtraces useful.
432 code
= emitter
.flush()
433 if os
.getenv("NMIGEN_pysim_dump"):
434 file = tempfile
.NamedTemporaryFile("w", prefix
="nmigen_pysim_", delete
=False)
438 filename
= "<string>"
440 exec_locals
= {"slots": self
.state
.slots
, **_ValueCompiler
.helpers
}
441 exec(compile(code
, filename
, "exec"), exec_locals
)
442 domain_process
.run
= exec_locals
["run"]
444 processes
.add(domain_process
)
446 for subfragment_index
, (subfragment
, subfragment_name
) in enumerate(fragment
.subfragments
):
447 if subfragment_name
is None:
448 subfragment_name
= "U${}".format(subfragment_index
)
449 processes
.update(self(subfragment
))