Give human-readable names to slots, run functions and filenames
[openpower-isa.git] / src / openpower / decoder / test / _pyrtl.py
index e559498959b7863077abc8d31fd8dcd7904bfa19..a9e7749ce67071e5a84a087913c5e0aebca07908 100644 (file)
@@ -1,22 +1,18 @@
 import os
-import tempfile
 from contextlib import contextmanager
 
-from nmigen.hdl import *
 from nmigen.hdl.ast import SignalSet
 from nmigen.hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
 from nmigen.sim._base import BaseProcess
 
-
 __all__ = ["PyRTLProcess"]
 
 
 class PyRTLProcess(BaseProcess):
-    __slots__ = ("is_comb", "runnable", "passive", "run")
+    __slots__ = ("is_comb", "runnable", "passive", "name", "crtl", "run")
 
     def __init__(self, *, is_comb):
         self.is_comb  = is_comb
-
         self.reset()
 
     def reset(self):
@@ -41,6 +37,15 @@ class _PythonEmitter:
         yield
         self._level -= 1
 
+    @contextmanager
+    def nest(self):
+        self.append(f"{{")
+        self._level += 1
+        #yield self.indent()
+        yield
+        self._level -= 1
+        self.append(f"}}")
+
     def flush(self, indent=""):
         code = "".join(self._buffer)
         self._buffer.clear()
@@ -53,9 +58,21 @@ class _PythonEmitter:
 
     def def_var(self, prefix, value):
         name = self.gen_var(prefix)
-        self.append(f"{name} = {value}")
+        self.append(f"uint64_t {name} = {value};")
         return name
 
+    def assign(self, lhs, rhs):
+        self.append(f"{lhs} = {rhs}")
+
+    def if_(self, cond):
+        self.append(f"if ({cond})")
+
+    def else_if(self, cond):
+        self.append(f"else if ({cond})")
+
+    def else_(self):
+        self.append(f"else")
+
 
 class _Compiler:
     def __init__(self, state, emitter):
@@ -104,10 +121,11 @@ class _RHSValueCompiler(_ValueCompiler):
         if self.inputs is not None:
             self.inputs.add(value)
 
+        macro = self.state.get_signal_macro(value)
         if self.mode == "curr":
-            return f"slots[{self.state.get_signal(value)}].{self.mode}"
+            return f"slots[{macro}].{self.mode}"
         else:
-            return f"next_{self.state.get_signal(value)}"
+            return f"next_{macro}"
 
     def on_Operator(self, value):
         def mask(value):
@@ -175,7 +193,7 @@ class _RHSValueCompiler(_ValueCompiler):
         elif len(value.operands) == 3:
             if value.operator == "m":
                 sel, val1, val0 = value.operands
-                return f"({self(val1)} if {mask(sel)} else {self(val0)})"
+                return f"(({mask(sel)}) ? ({self(val1)}) : ({self(val0)}))"
         raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:
 
     def on_Slice(self, value):
@@ -217,14 +235,14 @@ class _RHSValueCompiler(_ValueCompiler):
         if value.elems:
             for index, elem in enumerate(value.elems):
                 if index == 0:
-                    self.emitter.append(f"if {index} == {gen_index}:")
+                    self.emitter.if_(f"{index} == {gen_index}")
                 else:
-                    self.emitter.append(f"elif {index} == {gen_index}:")
-                with self.emitter.indent():
-                    self.emitter.append(f"{gen_value} = {self(elem)}")
-            self.emitter.append(f"else:")
-            with self.emitter.indent():
-                self.emitter.append(f"{gen_value} = {self(value.elems[-1])}")
+                    self.emitter.else_if(f"{index} == {gen_index}")
+                with self.emitter.nest():
+                    self.emitter.assign(f"{gen_value}", f"{self(elem)}")
+            self.emitter.else_()
+            with self.emitter.nest():
+                self.emitter.assign(f"{gen_value}", f"{self(value.elems[-1])}")
             return gen_value
         else:
             return f"0"
@@ -233,7 +251,7 @@ class _RHSValueCompiler(_ValueCompiler):
     def compile(cls, state, value, *, mode):
         emitter = _PythonEmitter()
         compiler = cls(state, emitter, mode=mode)
-        emitter.append(f"result = {compiler(value)}")
+        emitter.assign(f"result", f"{compiler(value)}")
         return emitter.flush()
 
 
@@ -262,7 +280,9 @@ class _LHSValueCompiler(_ValueCompiler):
                 value_sign = f"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
             else: # unsigned
                 value_sign = f"{value_mask} & {arg}"
-            self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
+            
+            macro = self.state.get_signal_macro(value)
+            self.emitter.append(f"next_{macro} = {value_sign};")
         return gen
 
     def on_Operator(self, value):
@@ -306,16 +326,14 @@ class _LHSValueCompiler(_ValueCompiler):
             if value.elems:
                 for index, elem in enumerate(value.elems):
                     if index == 0:
-                        self.emitter.append(f"if {index} == {gen_index}:")
+                        self.emitter.if_(f"{index} == {gen_index}")
                     else:
-                        self.emitter.append(f"elif {index} == {gen_index}:")
-                    with self.emitter.indent():
+                        self.emitter.append(f"{index} == {gen_index}")
+                    with self.emitter.nest():
                         self(elem)(arg)
-                self.emitter.append(f"else:")
-                with self.emitter.indent():
+                self.emitter.else_
+                with self.emitter.nest():
                     self(value.elems[-1])(arg)
-            else:
-                self.emitter.append(f"pass")
         return gen
 
 
@@ -354,12 +372,15 @@ class _StatementCompiler(StatementVisitor, _Compiler):
                         value = int(pattern, 2)
                         gen_checks.append(f"{value} == {gen_test}")
             if index == 0:
-                self.emitter.append(f"if {' or '.join(gen_checks)}:")
+                self.emitter.if_(f"{' or '.join(gen_checks)}")
             else:
-                self.emitter.append(f"elif {' or '.join(gen_checks)}:")
-            with self.emitter.indent():
+                self.emitter.else_if(f"{' or '.join(gen_checks)}")
+            with self.emitter.nest():
                 self(stmts)
 
+    def on_Display(self, stmt):
+        raise NotImplementedError # :nocov:
+
     def on_Assert(self, stmt):
         raise NotImplementedError # :nocov:
 
@@ -371,14 +392,15 @@ class _StatementCompiler(StatementVisitor, _Compiler):
 
     @classmethod
     def compile(cls, state, stmt):
-        output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
+        output_macros = \
+            [state.get_signal_macro(signal) for signal in stmt._lhs_signals()]
         emitter = _PythonEmitter()
-        for signal_index in output_indexes:
-            emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
+        for macro in output_macros:
+            emitter.append(f"uint64_t next_{macro} = slots[{macro}].next")
         compiler = cls(state, emitter)
         compiler(stmt)
-        for signal_index in output_indexes:
-            emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
+        for macro in output_macros:
+            emitter.append(f"set({macro}, next_{macro})")
         return emitter.flush()
 
 
@@ -386,66 +408,63 @@ class _FragmentCompiler:
     def __init__(self, state):
         self.state = state
 
-    def __call__(self, fragment):
+    def __call__(self, fragment, fragment_name):
         processes = set()
 
-        for domain_name, domain_signals in fragment.drivers.items():
+        for index, (domain_name, domain_signals) in enumerate(fragment.drivers.items()):
             domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
             domain_process = PyRTLProcess(is_comb=domain_name is None)
+            domain_process.name = \
+                f"{fragment_name}__{domain_name or ''}" \
+                f"_{id(fragment)}_{index}"
 
             emitter = _PythonEmitter()
-            emitter.append(f"def run():")
-            emitter._level += 1
+            emitter.append(f"void run_{domain_process.name}(void)")
+            with emitter.nest():
+                if domain_name is None:
+                    for signal in domain_signals:
+                        macro = self.state.get_signal_macro(signal)
+                        emitter.append(
+                            f"uint64_t next_{macro} = {signal.reset};")
 
-            if domain_name is None:
-                for signal in domain_signals:
-                    signal_index = self.state.get_signal(signal)
-                    emitter.append(f"next_{signal_index} = {signal.reset}")
+                    inputs = SignalSet()
+                    _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)
 
-                inputs = SignalSet()
-                _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)
+                    for input in inputs:
+                        self.state.add_trigger(domain_process, input)
 
-                for input in inputs:
-                    self.state.add_trigger(domain_process, input)
+                else:
+                    domain = fragment.domains[domain_name]
+                    clk_trigger = 1 if domain.clk_edge == "pos" else 0
+                    self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
+                    if domain.rst is not None and domain.async_reset:
+                        rst_trigger = 1
+                        self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)
 
-            else:
-                domain = fragment.domains[domain_name]
-                clk_trigger = 1 if domain.clk_edge == "pos" else 0
-                self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
-                if domain.rst is not None and domain.async_reset:
-                    rst_trigger = 1
-                    self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)
+                    for signal in domain_signals:
+                        macro = self.state.get_signal_macro(signal)
+                        emitter.append(
+                            f"uint64_t next_{macro} = slots[{macro}].next;")
+
+                    _StatementCompiler(self.state, emitter)(domain_stmts)
 
                 for signal in domain_signals:
-                    signal_index = self.state.get_signal(signal)
-                    emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
-
-                _StatementCompiler(self.state, emitter)(domain_stmts)
-
-            for signal in domain_signals:
-                signal_index = self.state.get_signal(signal)
-                emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
-
-            # There shouldn't be any exceptions raised by the generated code, but if there are
-            # (almost certainly due to a bug in the code generator), use this environment variable
-            # to make backtraces useful.
-            code = emitter.flush()
-            if os.getenv("NMIGEN_pysim_dump"):
-                file = tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_", delete=False)
-                file.write(code)
-                filename = file.name
-            else:
-                filename = "<string>"
+                    macro = self.state.get_signal_macro(signal)
+                    emitter.append(f"set({macro}, next_{macro});")
+
+            code = "#include <stdint.h>\n"
+            code += "#include \"common.h\"\n"
+            code += emitter.flush()
 
-            exec_locals = {"slots": self.state.slots, **_ValueCompiler.helpers}
-            exec(compile(code, filename, "exec"), exec_locals)
-            domain_process.run = exec_locals["run"]
+            file = open(f"crtl/{domain_process.name}.c", "w")
+            file.write(code)
+            file.close()
 
             processes.add(domain_process)
 
         for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
             if subfragment_name is None:
                 subfragment_name = "U${}".format(subfragment_index)
-            processes.update(self(subfragment))
+            processes.update(self(subfragment, subfragment_name))
 
         return processes