back.rtlil: reorganize value compiler into LHS/RHS.
authorwhitequark <cz@m-labs.hk>
Sun, 16 Dec 2018 13:30:20 +0000 (13:30 +0000)
committerwhitequark <cz@m-labs.hk>
Sun, 16 Dec 2018 13:33:34 +0000 (13:33 +0000)
This also implements Cat on LHS.

nmigen/back/rtlil.py

index 9049066c843b3f25a2dea2cc9fd85003061c7c49..5641322031f5585807a021c658d79adb1dd26b71 100644 (file)
@@ -195,37 +195,12 @@ def src(src_loc):
     return "{}:{}".format(file, line)
 
 
-class _ValueTransformer(xfrm.AbstractValueTransformer):
-    operator_map = {
-        (1, "~"):    "$not",
-        (1, "-"):    "$neg",
-        (1, "b"):    "$reduce_bool",
-        (2, "+"):    "$add",
-        (2, "-"):    "$sub",
-        (2, "*"):    "$mul",
-        (2, "/"):    "$div",
-        (2, "%"):    "$mod",
-        (2, "**"):   "$pow",
-        (2, "<<"):   "$sshl",
-        (2, ">>"):   "$sshr",
-        (2, "&"):    "$and",
-        (2, "^"):    "$xor",
-        (2, "|"):    "$or",
-        (2, "=="):   "$eq",
-        (2, "!="):   "$ne",
-        (2, "<"):    "$lt",
-        (2, "<="):   "$le",
-        (2, ">"):    "$gt",
-        (2, ">="):   "$ge",
-        (3, "m"):    "$mux",
-    }
-
+class _ValueCompilerState:
     def __init__(self, rtlil):
-        self.rtlil  = rtlil
-        self.wires  = ast.ValueDict()
-        self.driven = ast.ValueDict()
-        self.ports  = ast.ValueDict()
-        self.is_lhs   = False
+        self.rtlil    = rtlil
+        self.wires    = ast.ValueDict()
+        self.driven   = ast.ValueDict()
+        self.ports    = ast.ValueDict()
         self.sub_name = None
 
     def add_driven(self, signal, sync):
@@ -241,13 +216,36 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
             kind = "inout"
         self.ports[signal] = (len(self.ports), kind)
 
-    @contextmanager
-    def lhs(self):
-        try:
-            self.is_lhs = True
-            yield
-        finally:
-            self.is_lhs = False
+    def resolve(self, signal):
+        if signal in self.wires:
+            return self.wires[signal]
+
+        if signal in self.ports:
+            port_id, port_kind = self.ports[signal]
+        else:
+            port_id = port_kind = None
+        if self.sub_name:
+            wire_name = "{}_{}".format(self.sub_name, signal.name)
+        else:
+            wire_name = signal.name
+
+        for attr_name, attr_signal in signal.attrs.items():
+            self.rtlil.attribute(attr_name, attr_signal)
+        wire_curr = self.rtlil.wire(width=signal.nbits, name=wire_name,
+                                    port_id=port_id, port_kind=port_kind,
+                                    src=src(signal.src_loc))
+        if signal in self.driven:
+            wire_next = self.rtlil.wire(width=signal.nbits, name=wire_curr + "$next",
+                                        src=src(signal.src_loc))
+        else:
+            wire_next = None
+        self.wires[signal] = (wire_curr, wire_next)
+
+        return wire_curr, wire_next
+
+    def resolve_curr(self, signal):
+        wire_curr, wire_next = self.resolve(signal)
+        return wire_curr
 
     @contextmanager
     def hierarchy(self, sub_name):
@@ -257,12 +255,58 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
         finally:
             self.sub_name = None
 
+
+class _ValueCompiler(xfrm.AbstractValueTransformer):
+    def __init__(self, state):
+        self.s = state
+
     def on_unknown(self, value):
         if value is None:
             return None
         else:
             super().on_unknown(value)
 
+    def on_ClockSignal(self, value):
+        raise NotImplementedError # :nocov:
+
+    def on_ResetSignal(self, value):
+        raise NotImplementedError # :nocov:
+
+    def on_Slice(self, value):
+        if value.end == value.start + 1:
+            return "{} [{}]".format(self(value.value), value.start)
+        else:
+            return "{} [{}:{}]".format(self(value.value), value.end - 1, value.start)
+
+    def on_Cat(self, value):
+        return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.operands])))
+
+
+class _RHSValueCompiler(_ValueCompiler):
+    operator_map = {
+        (1, "~"):    "$not",
+        (1, "-"):    "$neg",
+        (1, "b"):    "$reduce_bool",
+        (2, "+"):    "$add",
+        (2, "-"):    "$sub",
+        (2, "*"):    "$mul",
+        (2, "/"):    "$div",
+        (2, "%"):    "$mod",
+        (2, "**"):   "$pow",
+        (2, "<<"):   "$sshl",
+        (2, ">>"):   "$sshr",
+        (2, "&"):    "$and",
+        (2, "^"):    "$xor",
+        (2, "|"):    "$or",
+        (2, "=="):   "$eq",
+        (2, "!="):   "$ne",
+        (2, "<"):    "$lt",
+        (2, "<="):   "$le",
+        (2, ">"):    "$gt",
+        (2, ">="):   "$ge",
+        (3, "m"):    "$mux",
+    }
+
     def on_Const(self, value):
         if isinstance(value.value, str):
             return "{}'{}".format(value.nbits, value.value)
@@ -270,48 +314,15 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
             return "{}'{:b}".format(value.nbits, value.value)
 
     def on_Signal(self, value):
-        if value in self.wires:
-            wire_curr, wire_next = self.wires[value]
-        else:
-            if value in self.ports:
-                port_id, port_kind = self.ports[value]
-            else:
-                port_id = port_kind = None
-            if self.sub_name:
-                wire_name = "{}_{}".format(self.sub_name, value.name)
-            else:
-                wire_name = value.name
-            for attr_name, attr_value in value.attrs.items():
-                self.rtlil.attribute(attr_name, attr_value)
-            wire_curr = self.rtlil.wire(width=value.nbits, name=wire_name,
-                                        port_id=port_id, port_kind=port_kind,
-                                        src=src(value.src_loc))
-            if value in self.driven:
-                wire_next = self.rtlil.wire(width=value.nbits, name=wire_curr + "$next",
-                                            src=src(value.src_loc))
-            else:
-                wire_next = None
-            self.wires[value] = (wire_curr, wire_next)
-
-        if self.is_lhs:
-            if wire_next is None:
-                raise ValueError("Cannot return lhs for non-driven signal {}".format(repr(value)))
-            return wire_next
-        else:
-            return wire_curr
-
-    def on_ClockSignal(self, value):
-        raise NotImplementedError # :nocov:
-
-    def on_ResetSignal(self, value):
-        raise NotImplementedError # :nocov:
+        wire_curr, wire_next = self.s.resolve(value)
+        return wire_curr
 
     def on_Operator_unary(self, value):
         arg, = value.operands
         arg_bits, arg_sign = arg.shape()
         res_bits, res_sign = value.shape()
-        res = self.rtlil.wire(width=res_bits)
-        self.rtlil.cell(self.operator_map[(1, value.op)], ports={
+        res = self.s.rtlil.wire(width=res_bits)
+        self.s.rtlil.cell(self.operator_map[(1, value.op)], ports={
             "\\A": self(arg),
             "\\Y": res,
         }, params={
@@ -327,8 +338,8 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
 
         value_bits, value_sign = value.shape()
         if new_bits > value_bits:
-            res = self.rtlil.wire(width=new_bits)
-            self.rtlil.cell("$pos", ports={
+            res = self.s.rtlil.wire(width=new_bits)
+            self.s.rtlil.cell("$pos", ports={
                 "\\A": self(value),
                 "\\Y": res,
             }, params={
@@ -353,8 +364,8 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
             lhs_wire = self.match_shape(lhs, lhs_bits, lhs_sign)
             rhs_wire = self.match_shape(rhs, rhs_bits, rhs_sign)
         res_bits, res_sign = value.shape()
-        res = self.rtlil.wire(width=res_bits)
-        self.rtlil.cell(self.operator_map[(2, value.op)], ports={
+        res = self.s.rtlil.wire(width=res_bits)
+        self.s.rtlil.cell(self.operator_map[(2, value.op)], ports={
             "\\A": lhs_wire,
             "\\B": rhs_wire,
             "\\Y": res,
@@ -375,8 +386,8 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
         lhs_bits = rhs_bits = res_bits = max(lhs_bits, rhs_bits, res_bits)
         lhs_wire = self.match_shape(lhs, lhs_bits, lhs_sign)
         rhs_wire = self.match_shape(rhs, rhs_bits, rhs_sign)
-        res = self.rtlil.wire(width=res_bits)
-        self.rtlil.cell("$mux", ports={
+        res = self.s.rtlil.wire(width=res_bits)
+        self.s.rtlil.cell("$mux", ports={
             "\\A": lhs_wire,
             "\\B": rhs_wire,
             "\\S": self(sel),
@@ -395,20 +406,11 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
             assert value.op == "m"
             return self.on_Operator_mux(value)
         else:
-            raise TypeError
-
-    def on_Slice(self, value):
-        if value.end == value.start + 1:
-            return "{} [{}]".format(self(value.value), value.start)
-        else:
-            return "{} [{}:{}]".format(self(value.value), value.end - 1, value.start)
+            raise TypeError # :nocov:
 
     def on_Part(self, value):
         raise NotImplementedError
 
-    def on_Cat(self, value):
-        return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.operands])))
-
     def on_Repl(self, value):
         return "{{ {} }}".format(" ".join(self(value.value) for _ in range(value.count)))
 
@@ -416,27 +418,52 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
         raise NotImplementedError
 
 
+class _LHSValueCompiler(_ValueCompiler):
+    def on_Const(self, value):
+        raise TypeError # :nocov:
+
+    def on_Operator(self, value):
+        raise TypeError # :nocov:
+
+    def on_Signal(self, value):
+        wire_curr, wire_next = self.s.resolve(value)
+        if wire_next is None:
+            raise ValueError("Cannot return lhs for non-driven signal {}".format(repr(value)))
+        return wire_next
+
+    def on_Part(self, value):
+        raise NotImplementedError
+
+    def on_Repl(self, value):
+        raise TypeError # :nocov:
+
+    def on_ArrayProxy(self, value):
+        raise NotImplementedError
+
+
 def convert_fragment(builder, fragment, name, top):
     with builder.module(name or "anonymous", attrs={"top": 1} if top else {}) as module:
-        xformer = _ValueTransformer(module)
+        compiler_state = _ValueCompilerState(module)
+        rhs_compiler   = _RHSValueCompiler(compiler_state)
+        lhs_compiler   = _LHSValueCompiler(compiler_state)
 
         # Register all signals driven in the current fragment. This must be done first, as it
         # affects further codegen; e.g. whether sig$next signals will be generated and used.
         for domain, signal in fragment.iter_drivers():
-            xformer.add_driven(signal, sync=domain is not None)
+            compiler_state.add_driven(signal, sync=domain is not None)
 
         # Transform all signals used as ports in the current fragment eagerly and outside of
         # any hierarchy, to make sure they get sensible (non-prefixed) names.
         for signal in fragment.ports:
-            xformer.add_port(signal, fragment.ports[signal])
-            xformer(signal)
+            compiler_state.add_port(signal, fragment.ports[signal])
+            rhs_compiler(signal)
 
         # Transform all clocks clocks and resets eagerly and outside of any hierarchy, to make
         # sure they get sensible (non-prefixed) names. This does not affect semantics.
         for domain, _ in fragment.iter_sync():
             cd = fragment.domains[domain]
-            xformer(cd.clk)
-            xformer(cd.rst)
+            rhs_compiler(cd.clk)
+            rhs_compiler(cd.rst)
 
         # Transform all subfragments to their respective cells. Transforming signals connected
         # to their ports into wires eagerly makes sure they get sensible (prefixed with submodule
@@ -444,9 +471,9 @@ def convert_fragment(builder, fragment, name, top):
         for subfragment, sub_name in fragment.subfragments:
             sub_name, sub_port_map = \
                 convert_fragment(builder, subfragment, top=False, name=sub_name)
-            with xformer.hierarchy(sub_name):
+            with compiler_state.hierarchy(sub_name):
                 module.cell(sub_name, name=sub_name, ports={
-                    p: xformer(s) for p, s in sub_port_map.items()
+                    p: rhs_compiler(s) for p, s in sub_port_map.items()
                 })
 
         with module.process() as process:
@@ -455,11 +482,10 @@ def convert_fragment(builder, fragment, name, top):
                 # For every signal in sync domains, assign \sig$next to the current value (\sig).
                 for domain, signal in fragment.iter_drivers():
                     if domain is None:
-                        prev_value = xformer(ast.Const(signal.reset, signal.nbits))
+                        prev_value = ast.Const(signal.reset, signal.nbits)
                     else:
-                        prev_value = xformer(signal)
-                    with xformer.lhs():
-                        case.assign(xformer(signal), prev_value)
+                        prev_value = signal
+                    case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
 
                 # Convert statements into decision trees.
                 def _convert_stmts(case, stmts):
@@ -468,17 +494,15 @@ def convert_fragment(builder, fragment, name, top):
                             lhs_bits, lhs_sign = stmt.lhs.shape()
                             rhs_bits, rhs_sign = stmt.rhs.shape()
                             if lhs_bits == rhs_bits:
-                                rhs_sigspec = xformer(stmt.rhs)
+                                rhs_sigspec = rhs_compiler(stmt.rhs)
                             else:
                                 # In RTLIL, LHS and RHS of assignment must have exactly same width.
-                                rhs_sigspec = xformer.match_shape(
+                                rhs_sigspec = rhs_compiler.match_shape(
                                     stmt.rhs, lhs_bits, rhs_sign)
-                            with xformer.lhs():
-                                lhs_sigspec = xformer(stmt.lhs)
-                            case.assign(lhs_sigspec, rhs_sigspec)
+                            case.assign(lhs_compiler(stmt.lhs), rhs_sigspec)
 
                         elif isinstance(stmt, ast.Switch):
-                            with case.switch(xformer(stmt.test)) as switch:
+                            with case.switch(rhs_compiler(stmt.test)) as switch:
                                 for value, nested_stmts in stmt.cases.items():
                                     with switch.case(value) as nested_case:
                                         _convert_stmts(nested_case, nested_stmts)
@@ -489,12 +513,11 @@ def convert_fragment(builder, fragment, name, top):
                 _convert_stmts(case, fragment.statements)
 
             # For every signal in the sync domain, assign \sig's initial value (which will end up
-            # as the \init reg attribute) to the reset value. Note that this assigns \sig,
-            # not \sig$next.
+            # as the \init reg attribute) to the reset value.
             with process.sync("init") as sync:
                 for domain, signal in fragment.iter_sync():
-                    sync.update(xformer(signal),
-                                xformer(ast.Const(signal.reset, signal.nbits)))
+                    wire_curr, wire_next = compiler_state.resolve(signal)
+                    sync.update(wire_curr, rhs_compiler(ast.Const(signal.reset, signal.nbits)))
 
             # For every signal in every domain, assign \sig to \sig$next. The sensitivity list,
             # however, differs between domains: for comb domains, it is `always`, for sync domains
@@ -506,23 +529,22 @@ def convert_fragment(builder, fragment, name, top):
                     triggers.append(("always",))
                 else:
                     cd = fragment.domains[domain]
-                    triggers.append(("posedge", xformer(cd.clk)))
+                    triggers.append(("posedge", compiler_state.resolve_curr(cd.clk)))
                     if cd.async_reset:
-                        triggers.append(("posedge", xformer(cd.rst)))
+                        triggers.append(("posedge", compiler_state.resolve_curr(cd.rst)))
 
                 for trigger in triggers:
                     with process.sync(*trigger) as sync:
                         for signal in signals:
-                            lhs_sigspec = xformer(signal)
-                            with xformer.lhs():
-                                sync.update(lhs_sigspec, xformer(signal))
+                            wire_curr, wire_next = compiler_state.resolve(signal)
+                            sync.update(wire_curr, wire_next)
 
     # Finally, collect the names we've given to our ports in RTLIL, and correlate these with
     # the signals represented by these ports. If we are a submodule, this will be necessary
     # to create a cell for us in the parent module.
     port_map = OrderedDict()
     for signal in fragment.ports:
-        port_map[xformer(signal)] = signal
+        port_map[compiler_state.resolve_curr(signal)] = signal
 
     return module.name, port_map