3 from contextlib
import contextmanager
6 from ..hdl
.ast
import SignalSet
7 from ..hdl
.xfrm
import ValueVisitor
, StatementVisitor
, LHSGroupFilter
8 from ._base
import BaseProcess
# Names exported by `from <module> import *`: only the process class is public;
# every compiler/emitter helper in this module is private (underscore-prefixed).
__all__ = ["PyRTLProcess"]
14 class PyRTLProcess(BaseProcess
):
15 __slots__
= ("is_comb", "runnable", "passive", "run")
17 def __init__(self
, *, is_comb
):
18 self
.is_comb
= is_comb
23 self
.runnable
= self
.is_comb
def append(self, code):
    """Push one line of generated source onto the buffer.

    The line is prefixed with `self._level` copies of the indent string and
    terminated by a newline; the three pieces are appended to `self._buffer`
    in that order.
    """
    leading = " " * self._level
    for piece in (leading, code, "\n"):
        self._buffer.append(piece)
44 def flush(self
, indent
=""):
45 code
= "".join(self
._buffer
)
49 def gen_var(self
, prefix
):
50 name
= f
"{prefix}_{self._suffix}"
54 def def_var(self
, prefix
, value
):
55 name
= self
.gen_var(prefix
)
56 self
.append(f
"{name} = {value}")
61 def __init__(self
, state
, emitter
):
63 self
.emitter
= emitter
66 class _ValueCompiler(ValueVisitor
, _Compiler
):
68 "sign": lambda value
, sign
: value | sign
if value
& sign
else value
,
69 "zdiv": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
// rhs
,
70 "zmod": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
% rhs
,
def on_ClockSignal(self, value):
    """Visitor hook for ClockSignal values; not handled by this compiler."""
    raise NotImplementedError() # :nocov:
def on_ResetSignal(self, value):
    """Visitor hook for ResetSignal values; not handled by this compiler."""
    raise NotImplementedError() # :nocov:
def on_AnyConst(self, value):
    """Visitor hook for AnyConst values; not handled by this compiler."""
    raise NotImplementedError() # :nocov:
def on_AnySeq(self, value):
    """Visitor hook for AnySeq values; not handled by this compiler."""
    raise NotImplementedError() # :nocov:
def on_Sample(self, value):
    """Visitor hook for Sample values; not handled by this compiler."""
    raise NotImplementedError() # :nocov:
def on_Initial(self, value):
    """Visitor hook for Initial values; not handled by this compiler."""
    raise NotImplementedError() # :nocov:
92 class _RHSValueCompiler(_ValueCompiler
):
93 def __init__(self
, state
, emitter
, *, mode
, inputs
=None):
94 super().__init
__(state
, emitter
)
95 assert mode
in ("curr", "next")
97 # If not None, `inputs` gets populated with RHS signals.
def on_Const(self, value):
    """Render a constant as source text using the default formatting of `value.value`."""
    # format(x, "") goes through the same PyObject_Format path as f"{x}".
    return format(value.value, "")
103 def on_Signal(self
, value
):
104 if self
.inputs
is not None:
105 self
.inputs
.add(value
)
107 if self
.mode
== "curr":
108 return f
"slots[{self.state.get_signal(value)}].{self.mode}"
110 return f
"next_{self.state.get_signal(value)}"
112 def on_Operator(self
, value
):
114 value_mask
= (1 << len(value
)) - 1
115 return f
"({value_mask} & {self(value)})"
118 if value
.shape().signed
:
119 return f
"sign({mask(value)}, {-1 << (len(value) - 1)})"
123 if len(value
.operands
) == 1:
124 arg
, = value
.operands
125 if value
.operator
== "~":
126 return f
"(~{self(arg)})"
127 if value
.operator
== "-":
128 return f
"(-{sign(arg)})"
129 if value
.operator
== "b":
130 return f
"bool({mask(arg)})"
131 if value
.operator
== "r|":
132 return f
"(0 != {mask(arg)})"
133 if value
.operator
== "r&":
134 return f
"({(1 << len(arg)) - 1} == {mask(arg)})"
135 if value
.operator
== "r^":
136 # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
137 return f
"(format({mask(arg)}, 'b').count('1') % 2)"
138 if value
.operator
in ("u", "s"):
139 # These operators don't change the bit pattern, only its interpretation.
141 elif len(value
.operands
) == 2:
142 lhs
, rhs
= value
.operands
143 lhs_mask
= (1 << len(lhs
)) - 1
144 rhs_mask
= (1 << len(rhs
)) - 1
145 if value
.operator
== "+":
146 return f
"({sign(lhs)} + {sign(rhs)})"
147 if value
.operator
== "-":
148 return f
"({sign(lhs)} - {sign(rhs)})"
149 if value
.operator
== "*":
150 return f
"({sign(lhs)} * {sign(rhs)})"
151 if value
.operator
== "//":
152 return f
"zdiv({sign(lhs)}, {sign(rhs)})"
153 if value
.operator
== "%":
154 return f
"zmod({sign(lhs)}, {sign(rhs)})"
155 if value
.operator
== "&":
156 return f
"({self(lhs)} & {self(rhs)})"
157 if value
.operator
== "|":
158 return f
"({self(lhs)} | {self(rhs)})"
159 if value
.operator
== "^":
160 return f
"({self(lhs)} ^ {self(rhs)})"
161 if value
.operator
== "<<":
162 return f
"({sign(lhs)} << {sign(rhs)})"
163 if value
.operator
== ">>":
164 return f
"({sign(lhs)} >> {sign(rhs)})"
165 if value
.operator
== "==":
166 return f
"({sign(lhs)} == {sign(rhs)})"
167 if value
.operator
== "!=":
168 return f
"({sign(lhs)} != {sign(rhs)})"
169 if value
.operator
== "<":
170 return f
"({sign(lhs)} < {sign(rhs)})"
171 if value
.operator
== "<=":
172 return f
"({sign(lhs)} <= {sign(rhs)})"
173 if value
.operator
== ">":
174 return f
"({sign(lhs)} > {sign(rhs)})"
175 if value
.operator
== ">=":
176 return f
"({sign(lhs)} >= {sign(rhs)})"
177 elif len(value
.operands
) == 3:
178 if value
.operator
== "m":
179 sel
, val1
, val0
= value
.operands
180 return f
"({self(val1)} if {mask(sel)} else {self(val0)})"
181 raise NotImplementedError("Operator '{}' not implemented".format(value
.operator
)) # :nocov:
def on_Slice(self, value):
    """Compile a bit slice: shift the inner value right by `value.start`,
    then mask the result to the slice's width."""
    width_mask = (1 << len(value)) - 1
    # Compiling the inner value may emit helper statements, so it happens
    # exactly once and after the (side-effect-free) mask computation.
    inner = self(value.value)
    return f"({width_mask} & ({inner} >> {value.start}))"
def on_Part(self, value):
    """Compile a variable-offset part select.

    The offset expression is masked to its own width and scaled by
    `value.stride`. The base value is shifted right by that amount and
    masked to `value.width` bits.
    """
    index_mask = (1 << len(value.offset)) - 1
    # The offset is compiled before the base value so that any helper
    # statements are emitted in the same order as before.
    shift = f"({value.stride} * ({index_mask} & {self(value.offset)}))"
    width_mask = (1 << value.width) - 1
    base = self(value.value)
    # Python binds `>>` tighter than `&`, so this is mask & (base >> shift).
    return f"({width_mask} & {base} >> {shift})"
192 def on_Cat(self
, value
):
195 for part
in value
.parts
:
196 part_mask
= (1 << len(part
)) - 1
197 gen_parts
.append(f
"(({part_mask} & {self(part)}) << {offset})")
200 return f
"({' | '.join(gen_parts)})"
203 def on_Repl(self
, value
):
204 part_mask
= (1 << len(value
.value
)) - 1
205 gen_part
= self
.emitter
.def_var("repl", f
"{part_mask} & {self(value.value)}")
208 for _
in range(value
.count
):
209 gen_parts
.append(f
"({gen_part} << {offset})")
210 offset
+= len(value
.value
)
212 return f
"({' | '.join(gen_parts)})"
215 def on_ArrayProxy(self
, value
):
216 index_mask
= (1 << len(value
.index
)) - 1
217 gen_index
= self
.emitter
.def_var("rhs_index", f
"{index_mask} & {self(value.index)}")
218 gen_value
= self
.emitter
.gen_var("rhs_proxy")
221 for index
, elem
in enumerate(value
.elems
):
223 self
.emitter
.append(f
"if {index} == {gen_index}:")
225 self
.emitter
.append(f
"elif {index} == {gen_index}:")
226 with self
.emitter
.indent():
227 self
.emitter
.append(f
"{gen_value} = {self(elem)}")
228 self
.emitter
.append(f
"else:")
229 with self
.emitter
.indent():
230 self
.emitter
.append(f
"{gen_value} = {self(value.elems[-1])}")
236 def compile(cls
, state
, value
, *, mode
):
237 emitter
= _PythonEmitter()
238 compiler
= cls(state
, emitter
, mode
=mode
)
239 emitter
.append(f
"result = {compiler(value)}")
240 return emitter
.flush()
243 class _LHSValueCompiler(_ValueCompiler
):
244 def __init__(self
, state
, emitter
, *, rhs
, outputs
=None):
245 super().__init
__(state
, emitter
)
246 # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
247 # the offset of a Part.
249 # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
250 # update of an lvalue.
251 self
.lrhs
= _RHSValueCompiler(state
, emitter
, mode
="next", inputs
=None)
252 # If not None, `outputs` gets populated with signals on LHS.
253 self
.outputs
= outputs
def on_Const(self, value):
    """Constants are not valid assignment targets for this LHS compiler."""
    raise TypeError() # :nocov:
258 def on_Signal(self
, value
):
259 if self
.outputs
is not None:
260 self
.outputs
.add(value
)
263 value_mask
= (1 << len(value
)) - 1
264 if value
.shape().signed
:
265 value_sign
= f
"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
267 value_sign
= f
"{value_mask} & {arg}"
268 self
.emitter
.append(f
"next_{self.state.get_signal(value)} = {value_sign}")
def on_Operator(self, value):
    """Operator results are not valid assignment targets for this LHS compiler."""
    raise TypeError() # :nocov:
274 def on_Slice(self
, value
):
276 width_mask
= (1 << (value
.stop
- value
.start
)) - 1
277 self(value
.value
)(f
"({self.lrhs(value.value)} & " \
278 f
"{~(width_mask << value.start)} | " \
279 f
"(({width_mask} & {arg}) << {value.start}))")
282 def on_Part(self
, value
):
284 width_mask
= (1 << value
.width
) - 1
285 offset_mask
= (1 << len(value
.offset
)) - 1
286 offset
= f
"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
287 self(value
.value
)(f
"({self.lrhs(value.value)} & " \
288 f
"~({width_mask} << {offset}) | " \
289 f
"(({width_mask} & {arg}) << {offset}))")
292 def on_Cat(self
, value
):
294 gen_arg
= self
.emitter
.def_var("cat", arg
)
297 for part
in value
.parts
:
298 part_mask
= (1 << len(part
)) - 1
299 self(part
)(f
"({part_mask} & ({gen_arg} >> {offset}))")
def on_Repl(self, value):
    """Replications are not valid assignment targets for this LHS compiler."""
    raise TypeError() # :nocov:
306 def on_ArrayProxy(self
, value
):
308 index_mask
= (1 << len(value
.index
)) - 1
309 gen_index
= self
.emitter
.def_var("index", f
"{self.rrhs(value.index)} & {index_mask}")
312 for index
, elem
in enumerate(value
.elems
):
314 self
.emitter
.append(f
"if {index} == {gen_index}:")
316 self
.emitter
.append(f
"elif {index} == {gen_index}:")
317 with self
.emitter
.indent():
319 self
.emitter
.append(f
"else:")
320 with self
.emitter
.indent():
321 self(value
.elems
[-1])(arg
)
323 self
.emitter
.append(f
"pass")
327 class _StatementCompiler(StatementVisitor
, _Compiler
):
def __init__(self, state, emitter, *, inputs=None, outputs=None):
    """Create paired RHS and LHS compilers that share `state` and `emitter`.

    The RHS compiler reads signals in "curr" mode and receives `inputs`.
    The LHS compiler receives `outputs` and reuses the RHS compiler for
    rvalues nested inside lvalues. Either collection may be None.
    """
    super().__init__(state, emitter)
    reader = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
    self.rhs = reader
    self.lhs = _LHSValueCompiler(state, emitter, rhs=reader, outputs=outputs)
333 def on_statements(self
, stmts
):
337 self
.emitter
.append("pass")
def on_Assign(self, stmt):
    """Compile `stmt.lhs = stmt.rhs`.

    The RHS expression is masked to its width and, for signed shapes, wrapped
    in the `sign` helper. The result is handed to the generator that the LHS
    compiler returns for `stmt.lhs`.
    """
    source = stmt.rhs
    width = len(source)
    masked = f"({(1 << width) - 1} & {self.rhs(source)})"
    if not source.shape().signed:
        return self.lhs(stmt.lhs)(masked)
    # NOTE(review): a zero-width signed rhs would make `-1 << -1` raise
    # ValueError here; confirm such shapes cannot reach this point.
    sign_bit = -1 << (width - 1)
    return self.lhs(stmt.lhs)(f"sign({masked}, {sign_bit})")
345 def on_Switch(self
, stmt
):
346 gen_test
= self
.emitter
.def_var("test",
347 f
"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
348 for index
, (patterns
, stmts
) in enumerate(stmt
.cases
.items()):
351 gen_checks
.append(f
"True")
353 for pattern
in patterns
:
355 mask
= int("".join("0" if b
== "-" else "1" for b
in pattern
), 2)
356 value
= int("".join("0" if b
== "-" else b
for b
in pattern
), 2)
357 gen_checks
.append(f
"{value} == ({mask} & {gen_test})")
359 value
= int(pattern
, 2)
360 gen_checks
.append(f
"{value} == {gen_test}")
362 self
.emitter
.append(f
"if {' or '.join(gen_checks)}:")
364 self
.emitter
.append(f
"elif {' or '.join(gen_checks)}:")
365 with self
.emitter
.indent():
def on_Assert(self, stmt):
    """Assert statements are not supported by this statement compiler."""
    raise NotImplementedError() # :nocov:
def on_Assume(self, stmt):
    """Assume statements are not supported by this statement compiler."""
    raise NotImplementedError() # :nocov:
def on_Cover(self, stmt):
    """Cover statements are not supported by this statement compiler."""
    raise NotImplementedError() # :nocov:
378 def compile(cls
, state
, stmt
):
379 output_indexes
= [state
.get_signal(signal
) for signal
in stmt
._lhs
_signals
()]
380 emitter
= _PythonEmitter()
381 for signal_index
in output_indexes
:
382 emitter
.append(f
"next_{signal_index} = slots[{signal_index}].next")
383 compiler
= cls(state
, emitter
)
385 for signal_index
in output_indexes
:
386 emitter
.append(f
"slots[{signal_index}].set(next_{signal_index})")
387 return emitter
.flush()
390 class _FragmentCompiler
:
391 def __init__(self
, state
):
394 def __call__(self
, fragment
):
397 for domain_name
, domain_signals
in fragment
.drivers
.items():
398 domain_stmts
= LHSGroupFilter(domain_signals
)(fragment
.statements
)
399 domain_process
= PyRTLProcess(is_comb
=domain_name
is None)
401 emitter
= _PythonEmitter()
402 emitter
.append(f
"def run():")
405 if domain_name
is None:
406 for signal
in domain_signals
:
407 signal_index
= self
.state
.get_signal(signal
)
408 emitter
.append(f
"next_{signal_index} = {signal.reset}")
411 _StatementCompiler(self
.state
, emitter
, inputs
=inputs
)(domain_stmts
)
414 self
.state
.add_trigger(domain_process
, input)
417 domain
= fragment
.domains
[domain_name
]
418 clk_trigger
= 1 if domain
.clk_edge
== "pos" else 0
419 self
.state
.add_trigger(domain_process
, domain
.clk
, trigger
=clk_trigger
)
420 if domain
.rst
is not None and domain
.async_reset
:
422 self
.state
.add_trigger(domain_process
, domain
.rst
, trigger
=rst_trigger
)
424 for signal
in domain_signals
:
425 signal_index
= self
.state
.get_signal(signal
)
426 emitter
.append(f
"next_{signal_index} = slots[{signal_index}].next")
428 _StatementCompiler(self
.state
, emitter
)(domain_stmts
)
430 for signal
in domain_signals
:
431 signal_index
= self
.state
.get_signal(signal
)
432 emitter
.append(f
"slots[{signal_index}].set(next_{signal_index})")
434 # There shouldn't be any exceptions raised by the generated code, but if there are
435 # (almost certainly due to a bug in the code generator), use this environment variable
436 # to make backtraces useful.
437 code
= emitter
.flush()
438 if os
.getenv("NMIGEN_pysim_dump"):
439 file = tempfile
.NamedTemporaryFile("w", prefix
="nmigen_pysim_", delete
=False)
443 filename
= "<string>"
445 exec_locals
= {"slots": self
.state
.slots
, **_ValueCompiler
.helpers
}
446 exec(compile(code
, filename
, "exec"), exec_locals
)
447 domain_process
.run
= exec_locals
["run"]
449 processes
.add(domain_process
)
451 for subfragment_index
, (subfragment
, subfragment_name
) in enumerate(fragment
.subfragments
):
452 if subfragment_name
is None:
453 subfragment_name
= "U${}".format(subfragment_index
)
454 processes
.update(self(subfragment
))