2 from contextlib
import contextmanager
4 from nmigen
.hdl
.ast
import SignalSet
5 from nmigen
.hdl
.xfrm
import ValueVisitor
, StatementVisitor
, LHSGroupFilter
6 from nmigen
.sim
._base
import BaseProcess
7 from openpower
.decoder
.test
.crtl_path
import get_crtl_path
9 __all__
= ["PyRTLProcess"]
12 class PyRTLProcess(BaseProcess
):
13 __slots__
= ("is_comb", "runnable", "passive", "name", "crtl", "run")
15 def __init__(self
, *, is_comb
):
16 self
.is_comb
= is_comb
20 self
.runnable
= self
.is_comb
def append(self, code):
    """Queue one line of generated code on ``self._buffer``.

    Three fragments are pushed, in order: an indentation prefix repeated
    ``self._level`` times, ``code`` itself, and a newline.

    :param code: text of the line, without a trailing newline.
    """
    # NOTE(review): the indent unit is reproduced as it appears in this view
    # (a single space) -- confirm against the upstream file.
    indent = " " * self._level
    self._buffer.extend((indent, code, "\n"))
50 def flush(self
, indent
=""):
51 code
= "".join(self
._buffer
)
55 def gen_var(self
, prefix
):
56 name
= f
"{prefix}_{self._suffix}"
60 def def_var(self
, prefix
, value
):
61 name
= self
.gen_var(prefix
)
62 self
.append(f
"uint64_t {name} = {value};")
def assign(self, lhs, rhs):
    """Emit the line ``<lhs> = <rhs>`` through :meth:`append`.

    :param lhs: text of the assignment target.
    :param rhs: text of the assigned expression.
    """
    # NOTE(review): no statement terminator is added here, while other
    # emitted statements in this file end in ";" -- confirm this is intended.
    self.append(f"{lhs} = {rhs}")
69 self
.append(f
"if ({cond})")
def else_if(self, cond):
    """Emit an ``else if (<cond>)`` header line through :meth:`append`.

    :param cond: text of the condition expression.
    """
    self.append(f"else if ({cond})")
79 def __init__(self
, state
, emitter
):
81 self
.emitter
= emitter
84 class _ValueCompiler(ValueVisitor
, _Compiler
):
86 "sign": lambda value
, sign
: value | sign
if value
& sign
else value
,
87 "zdiv": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
// rhs
,
88 "zmod": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
% rhs
,
def on_ClockSignal(self, value):
    """Visitor hook for ``ClockSignal``; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_ResetSignal(self, value):
    """Visitor hook for ``ResetSignal``; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_AnyConst(self, value):
    """Visitor hook for ``AnyConst``; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_AnySeq(self, value):
    """Visitor hook for ``AnySeq``; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Sample(self, value):
    """Visitor hook for ``Sample``; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Initial(self, value):
    """Visitor hook for ``Initial``; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
110 class _RHSValueCompiler(_ValueCompiler
):
111 def __init__(self
, state
, emitter
, *, mode
, inputs
=None):
112 super().__init
__(state
, emitter
)
113 assert mode
in ("curr", "next")
115 # If not None, `inputs` gets populated with RHS signals.
def on_SmtExpr(self, value):
    """Visitor hook for ``SmtExpr`` on the right-hand side; not supported.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def on_Const(self, value):
    """Return the constant's ``value`` attribute formatted as text.

    :param value: node exposing a ``value`` attribute.
    :returns: ``value.value`` rendered through an f-string.
    """
    return f"{value.value}"
def on_Signal(self, value):
    """Return the generated-code expression that reads signal ``value``.

    When ``self.inputs`` is not None, ``value`` is first added to it so the
    caller can learn which signals the expression reads.

    :param value: signal node; passed to ``self.state.get_signal_macro``.
    :returns: ``slots[<macro>].curr`` in "curr" mode, ``next_<macro>``
        in any other mode.
    """
    if self.inputs is not None:
        self.inputs.add(value)

    macro = self.state.get_signal_macro(value)
    if self.mode == "curr":
        return f"slots[{macro}].{self.mode}"
    # Any mode other than "curr" reads the pending-value local.
    return f"next_{macro}"
134 def on_Operator(self
, value
):
136 value_mask
= (1 << len(value
)) - 1
137 return f
"({value_mask} & {self(value)})"
140 if value
.shape().signed
:
141 return f
"sign({mask(value)}, {-1 << (len(value) - 1)})"
145 if len(value
.operands
) == 1:
146 arg
, = value
.operands
147 if value
.operator
== "~":
148 return f
"(~{self(arg)})"
149 if value
.operator
== "-":
150 return f
"(-{sign(arg)})"
151 if value
.operator
== "b":
152 return f
"!!({mask(arg)})"
153 if value
.operator
== "r|":
154 return f
"(0 != {mask(arg)})"
155 if value
.operator
== "r&":
156 return f
"({(1 << len(arg)) - 1} == {mask(arg)})"
157 if value
.operator
== "r^":
158 # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
159 return f
"(format({mask(arg)}, 'b').count('1') % 2)"
160 if value
.operator
in ("u", "s"):
161 # These operators don't change the bit pattern, only its interpretation.
163 elif len(value
.operands
) == 2:
164 lhs
, rhs
= value
.operands
165 if value
.operator
== "+":
166 return f
"({sign(lhs)} + {sign(rhs)})"
167 if value
.operator
== "-":
168 return f
"({sign(lhs)} - {sign(rhs)})"
169 if value
.operator
== "*":
170 return f
"({sign(lhs)} * {sign(rhs)})"
171 if value
.operator
== "//":
172 return f
"zdiv({sign(lhs)}, {sign(rhs)})"
173 if value
.operator
== "%":
174 return f
"zmod({sign(lhs)}, {sign(rhs)})"
175 if value
.operator
== "&":
176 return f
"({self(lhs)} & {self(rhs)})"
177 if value
.operator
== "|":
178 return f
"({self(lhs)} | {self(rhs)})"
179 if value
.operator
== "^":
180 return f
"({self(lhs)} ^ {self(rhs)})"
181 if value
.operator
== "<<":
182 return f
"({sign(lhs)} << {sign(rhs)})"
183 if value
.operator
== ">>":
184 return f
"({sign(lhs)} >> {sign(rhs)})"
185 if value
.operator
== "==":
186 return f
"({sign(lhs)} == {sign(rhs)})"
187 if value
.operator
== "!=":
188 return f
"({sign(lhs)} != {sign(rhs)})"
189 if value
.operator
== "<":
190 return f
"({sign(lhs)} < {sign(rhs)})"
191 if value
.operator
== "<=":
192 return f
"({sign(lhs)} <= {sign(rhs)})"
193 if value
.operator
== ">":
194 return f
"({sign(lhs)} > {sign(rhs)})"
195 if value
.operator
== ">=":
196 return f
"({sign(lhs)} >= {sign(rhs)})"
197 elif len(value
.operands
) == 3:
198 if value
.operator
== "m":
199 sel
, val1
, val0
= value
.operands
200 return f
"(({mask(sel)}) ? ({self(val1)}) : ({self(val0)}))"
201 raise NotImplementedError("Operator '{}' not implemented".format(value
.operator
)) # :nocov:
def on_Slice(self, value):
    """Return an expression extracting a bit slice of ``value.value``.

    The compiled operand is shifted right by ``value.start`` and masked to
    ``len(value)`` bits.

    :param value: slice node with ``value``, ``start`` and a length.
    :returns: text of the form ``(<mask> & (<operand> >> <start>))``.
    """
    width_mask = (1 << len(value)) - 1
    operand = self(value.value)
    return f"({width_mask} & ({operand} >> {value.start}))"
def on_Part(self, value):
    """Return an expression for a variable-offset part select.

    The shift amount is ``value.stride`` times the compiled offset, with the
    offset first masked to ``len(value.offset)`` bits. The shifted operand is
    then masked to ``value.width`` bits.

    :param value: node with ``value``, ``offset``, ``stride`` and ``width``.
    :returns: text of the form ``(<mask> & <operand> >> <shift>)``.
    """
    offset_mask = (1 << len(value.offset)) - 1
    offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
    width_mask = (1 << value.width) - 1
    # No parentheses around the shift: the emitted text relies on ">>"
    # binding tighter than "&".
    return f"({width_mask} & {self(value.value)} >> {offset})"
212 def on_Cat(self
, value
):
215 for part
in value
.parts
:
216 part_mask
= (1 << len(part
)) - 1
217 gen_parts
.append(f
"(({part_mask} & {self(part)}) << {offset})")
220 return f
"({' | '.join(gen_parts)})"
223 def on_Repl(self
, value
):
224 part_mask
= (1 << len(value
.value
)) - 1
225 gen_part
= self
.emitter
.def_var("repl", f
"{part_mask} & {self(value.value)}")
228 for _
in range(value
.count
):
229 gen_parts
.append(f
"({gen_part} << {offset})")
230 offset
+= len(value
.value
)
232 return f
"({' | '.join(gen_parts)})"
235 def on_ArrayProxy(self
, value
):
236 index_mask
= (1 << len(value
.index
)) - 1
237 gen_index
= self
.emitter
.def_var("rhs_index", f
"{index_mask} & {self(value.index)}")
238 gen_value
= self
.emitter
.gen_var("rhs_proxy")
240 for index
, elem
in enumerate(value
.elems
):
242 self
.emitter
.if_(f
"{index} == {gen_index}")
244 self
.emitter
.else_if(f
"{index} == {gen_index}")
245 with self
.emitter
.nest():
246 self
.emitter
.assign(f
"{gen_value}", f
"{self(elem)}")
248 with self
.emitter
.nest():
249 self
.emitter
.assign(f
"{gen_value}", f
"{self(value.elems[-1])}")
def compile(cls, state, value, *, mode):
    """Compile ``value`` into code that assigns it to ``result``.

    A fresh ``_PythonEmitter`` and a compiler instance of ``cls`` are
    created, a single ``result = <expr>`` assignment is emitted, and the
    buffered text is returned.

    :param state: forwarded to the ``cls`` constructor.
    :param value: value node to translate.
    :param mode: forwarded to the ``cls`` constructor as ``mode``.
    :returns: whatever ``emitter.flush()`` returns.
    """
    # NOTE(review): the first parameter is ``cls`` but no decorator line is
    # visible in this view -- confirm ``@classmethod`` precedes this def.
    emitter = _PythonEmitter()
    compiler = cls(state, emitter, mode=mode)
    emitter.assign("result", f"{compiler(value)}")
    return emitter.flush()
262 class _LHSValueCompiler(_ValueCompiler
):
263 def __init__(self
, state
, emitter
, *, rhs
, outputs
=None):
264 super().__init
__(state
, emitter
)
265 # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
266 # the offset of a Part.
268 # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
269 # update of an lvalue.
270 self
.lrhs
= _RHSValueCompiler(state
, emitter
, mode
="next", inputs
=None)
271 # If not None, `outputs` gets populated with signals on LHS.
272 self
.outputs
= outputs
def on_SmtExpr(self, value):
    """Visitor hook for ``SmtExpr`` on the left-hand side; not supported.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def on_Const(self, value):
    """Visitor hook for ``Const`` on the left-hand side; always rejected.

    :raises TypeError: always.
    """
    raise TypeError # :nocov:
280 def on_Signal(self
, value
):
281 if self
.outputs
is not None:
282 self
.outputs
.add(value
)
285 value_mask
= (1 << len(value
)) - 1
286 if value
.shape().signed
:
287 value_sign
= f
"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
289 value_sign
= f
"{value_mask} & {arg}"
291 macro
= self
.state
.get_signal_macro(value
)
292 self
.emitter
.append(f
"next_{macro} = {value_sign};")
def on_Operator(self, value):
    """Visitor hook for ``Operator`` on the left-hand side; always rejected.

    :raises TypeError: always.
    """
    raise TypeError # :nocov:
298 def on_Slice(self
, value
):
300 width_mask
= (1 << (value
.stop
- value
.start
)) - 1
301 self(value
.value
)(f
"(({self.lrhs(value.value)} & " \
302 f
"{~(width_mask << value.start)}) | " \
303 f
"(({width_mask} & {arg}) << {value.start}))")
306 def on_Part(self
, value
):
308 width_mask
= (1 << value
.width
) - 1
309 offset_mask
= (1 << len(value
.offset
)) - 1
310 offset
= f
"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
311 self(value
.value
)(f
"({self.lrhs(value.value)} & " \
312 f
"~({width_mask} << {offset}) | " \
313 f
"(({width_mask} & {arg}) << {offset}))")
316 def on_Cat(self
, value
):
318 gen_arg
= self
.emitter
.def_var("cat", arg
)
320 for part
in value
.parts
:
321 part_mask
= (1 << len(part
)) - 1
322 self(part
)(f
"({part_mask} & ({gen_arg} >> {offset}))")
def on_Repl(self, value):
    """Visitor hook for ``Repl`` on the left-hand side; always rejected.

    :raises TypeError: always.
    """
    raise TypeError # :nocov:
329 def on_ArrayProxy(self
, value
):
331 index_mask
= (1 << len(value
.index
)) - 1
332 gen_index
= self
.emitter
.def_var("index", f
"{self.rrhs(value.index)} & {index_mask}")
334 for index
, elem
in enumerate(value
.elems
):
336 self
.emitter
.if_(f
"{index} == {gen_index}")
338 self
.emitter
.append(f
"{index} == {gen_index}")
339 with self
.emitter
.nest():
342 with self
.emitter
.nest():
343 self(value
.elems
[-1])(arg
)
347 class _StatementCompiler(StatementVisitor
, _Compiler
):
def __init__(self, state, emitter, *, inputs=None, outputs=None):
    """Set up the statement compiler and its two value compilers.

    :param state: forwarded to the base initialiser and both value compilers.
    :param emitter: forwarded likewise; the sink for generated code.
    :param inputs: passed to the RHS compiler as ``inputs``.
    :param outputs: passed to the LHS compiler as ``outputs``.
    """
    super().__init__(state, emitter)
    # Right-hand sides are compiled in "curr" mode.
    self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
    # The LHS compiler receives the RHS compiler built just above.
    self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)
353 def on_statements(self
, stmts
):
357 self
.emitter
.append("/* pass */;")
def on_Assign(self, stmt):
    """Compile one assignment statement.

    The right-hand side is compiled with ``self.rhs`` and masked to
    ``len(stmt.rhs)`` bits. If its shape is signed, the masked text is wrapped
    in a ``sign(...)`` call whose second argument is
    ``-1 << (len(stmt.rhs) - 1)``. The result is handed to the generator
    obtained from ``self.lhs(stmt.lhs)``.

    :param stmt: statement with ``lhs`` and ``rhs`` attributes.
    :returns: whatever the left-hand-side generator returns.
    """
    rhs_width = len(stmt.rhs)
    gen_rhs = f"({(1 << rhs_width) - 1} & {self.rhs(stmt.rhs)})"
    if stmt.rhs.shape().signed:
        gen_rhs = f"sign({gen_rhs}, {-1 << (rhs_width - 1)})"
    return self.lhs(stmt.lhs)(gen_rhs)
365 def on_Switch(self
, stmt
):
366 gen_test
= self
.emitter
.def_var("test",
367 f
"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
368 for index
, (patterns
, stmts
) in enumerate(stmt
.cases
.items()):
371 gen_checks
.append(f
"1 /* True */")
373 for pattern
in patterns
:
375 mask
= int("".join("0" if b
== "-" else "1" for b
in pattern
), 2)
376 value
= int("".join("0" if b
== "-" else b
for b
in pattern
), 2)
377 gen_checks
.append(f
"{value} == ({mask} & {gen_test})")
379 value
= int(pattern
, 2)
380 gen_checks
.append(f
"{value} == {gen_test}")
382 self
.emitter
.if_(f
"{' || '.join(gen_checks)}")
384 self
.emitter
.else_if(f
"{' || '.join(gen_checks)}")
385 with self
.emitter
.nest():
def on_Display(self, stmt):
    """Visitor hook for ``Display`` statements; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Assert(self, stmt):
    """Visitor hook for ``Assert`` statements; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Assume(self, stmt):
    """Visitor hook for ``Assume`` statements; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def on_Cover(self, stmt):
    """Visitor hook for ``Cover`` statements; no translation is provided.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
401 def compile(cls
, state
, stmt
):
403 [state
.get_signal_macro(signal
) for signal
in stmt
._lhs
_signals
()]
404 emitter
= _PythonEmitter()
405 for macro
in output_macros
:
406 emitter
.append(f
"uint64_t next_{macro} = slots[{macro}].next")
407 compiler
= cls(state
, emitter
)
409 for macro
in output_macros
:
410 emitter
.append(f
"set({macro}, next_{macro})")
411 return emitter
.flush()
414 class _FragmentCompiler
:
415 def __init__(self
, state
):
418 def __call__(self
, fragment
, fragment_name
):
421 for index
, (domain_name
, domain_signals
) in enumerate(fragment
.drivers
.items()):
422 domain_stmts
= LHSGroupFilter(domain_signals
)(fragment
.statements
)
423 domain_process
= PyRTLProcess(is_comb
=domain_name
is None)
424 domain_process
.name
= \
425 f
"{fragment_name}__{domain_name or ''}" \
426 f
"_{id(fragment)}_{index}"
428 emitter
= _PythonEmitter()
429 emitter
.append(f
"void run_{domain_process.name}(void)")
431 if domain_name
is None:
432 for signal
in domain_signals
:
433 macro
= self
.state
.get_signal_macro(signal
)
435 f
"uint64_t next_{macro} = {signal.reset};")
438 _StatementCompiler(self
.state
, emitter
, inputs
=inputs
)(domain_stmts
)
441 self
.state
.add_trigger(domain_process
, input)
444 domain
= fragment
.domains
[domain_name
]
445 clk_trigger
= 1 if domain
.clk_edge
== "pos" else 0
446 self
.state
.add_trigger(domain_process
, domain
.clk
, trigger
=clk_trigger
)
447 if domain
.rst
is not None and domain
.async_reset
:
449 self
.state
.add_trigger(domain_process
, domain
.rst
, trigger
=rst_trigger
)
451 for signal
in domain_signals
:
452 macro
= self
.state
.get_signal_macro(signal
)
454 f
"uint64_t next_{macro} = slots[{macro}].next;")
456 _StatementCompiler(self
.state
, emitter
)(domain_stmts
)
458 for signal
in domain_signals
:
459 macro
= self
.state
.get_signal_macro(signal
)
460 emitter
.append(f
"set({macro}, next_{macro});")
462 code
= "#include <stdint.h>\n"
463 code
+= "#include \"common.h\"\n"
464 code
+= emitter
.flush()
466 crtl
= get_crtl_path()
468 file = open(os
.path
.join(crtl
, f
"{domain_process.name}.c"), "w")
472 processes
.add(domain_process
)
474 for subfragment_index
, (subfragment
, subfragment_name
) in enumerate(fragment
.subfragments
):
475 if subfragment_name
is None:
476 subfragment_name
= "U${}".format(subfragment_index
)
477 processes
.update(self(subfragment
, subfragment_name
))