back.rtlil: fix legalization of Part() with stride.
[nmigen.git] / nmigen / back / rtlil.py
index 298d2f9fc5eb17d1c9c569843e3bb7adc70281e8..c074a97f5042fd8c9a04c393192d7efc7077d2ef 100644 (file)
@@ -3,7 +3,7 @@ import textwrap
 from collections import defaultdict, OrderedDict
 from contextlib import contextmanager
 
-from ..tools import bits_for, flatten
+from .._utils import bits_for, flatten
 from ..hdl import ast, rec, ir, mem, xfrm
 
 
@@ -128,8 +128,8 @@ class _ModuleBuilder(_Namer, _BufferedBuilder, _AttrBuilder):
                 self._append("    parameter \\{} \"{}\"\n",
                              param, value.translate(self._escape_map))
             elif isinstance(value, int):
-                self._append("    parameter \\{} {:d}\n",
-                             param, value)
+                self._append("    parameter \\{} {}'{:b}\n",
+                             param, bits_for(value), value)
             elif isinstance(value, float):
                 self._append("    parameter real \\{} \"{!r}\"\n",
                              param, value)
@@ -278,6 +278,9 @@ class _ValueCompilerState:
         self.ports[signal] = (len(self.ports), kind)
 
     def resolve(self, signal, prefix=None):
+        if len(signal) == 0:
+            return "{ }", "{ }"
+
         if signal in self.wires:
             return self.wires[signal]
 
@@ -354,16 +357,16 @@ class _ValueCompiler(xfrm.ValueVisitor):
         raise NotImplementedError # :nocov:
 
     def on_Slice(self, value):
-        if value.start == 0 and value.end == len(value.value):
+        if value.start == 0 and value.stop == len(value.value):
             return self(value.value)
 
         sigspec = self._prepare_value_for_Slice(value.value)
-        if value.start == value.end:
+        if value.start == value.stop:
             return "{}"
-        elif value.start + 1 == value.end:
+        elif value.start + 1 == value.stop:
             return "{} [{}]".format(sigspec, value.start)
         else:
-            return "{} [{}:{}]".format(sigspec, value.end - 1, value.start)
+            return "{} [{}:{}]".format(sigspec, value.stop - 1, value.start)
 
     def on_ArrayProxy(self, value):
         index = self.s.expand(value.index)
@@ -374,7 +377,9 @@ class _ValueCompiler(xfrm.ValueVisitor):
                 elem = value.elems[-1]
             return self.match_shape(elem, *value.shape())
         else:
-            raise LegalizeValue(value.index, range(len(value.elems)), value.src_loc)
+            max_index = 1 << len(value.index)
+            max_elem  = len(value.elems)
+            raise LegalizeValue(value.index, range(min(max_index, max_elem)), value.src_loc)
 
 
 class _RHSValueCompiler(_ValueCompiler):
@@ -449,10 +454,14 @@ class _RHSValueCompiler(_ValueCompiler):
 
     def on_Operator_unary(self, value):
         arg, = value.operands
+        if value.operator in ("u", "s"):
+            # These operators don't change the bit pattern, only its interpretation.
+            return self(arg)
+
         arg_bits, arg_sign = arg.shape()
         res_bits, res_sign = value.shape()
         res = self.s.rtlil.wire(width=res_bits, src=src(value.src_loc))
-        self.s.rtlil.cell(self.operator_map[(1, value.op)], ports={
+        self.s.rtlil.cell(self.operator_map[(1, value.operator)], ports={
             "\\A": self(arg),
             "\\Y": res,
         }, params={
@@ -464,7 +473,7 @@ class _RHSValueCompiler(_ValueCompiler):
 
     def match_shape(self, value, new_bits, new_sign):
         if isinstance(value, ast.Const):
-            return self(ast.Const(value.value, (new_bits, new_sign)))
+            return self(ast.Const(value.value, ast.Shape(new_bits, new_sign)))
 
         value_bits, value_sign = value.shape()
         if new_bits <= value_bits:
@@ -485,16 +494,17 @@ class _RHSValueCompiler(_ValueCompiler):
         lhs, rhs = value.operands
         lhs_bits, lhs_sign = lhs.shape()
         rhs_bits, rhs_sign = rhs.shape()
-        if lhs_sign == rhs_sign:
+        if lhs_sign == rhs_sign or value.operator in ("<<", ">>", "**"):
             lhs_wire = self(lhs)
             rhs_wire = self(rhs)
         else:
             lhs_sign = rhs_sign = True
+            lhs_bits = rhs_bits = max(lhs_bits, rhs_bits)
             lhs_wire = self.match_shape(lhs, lhs_bits, lhs_sign)
             rhs_wire = self.match_shape(rhs, rhs_bits, rhs_sign)
         res_bits, res_sign = value.shape()
         res = self.s.rtlil.wire(width=res_bits, src=src(value.src_loc))
-        self.s.rtlil.cell(self.operator_map[(2, value.op)], ports={
+        self.s.rtlil.cell(self.operator_map[(2, value.operator)], ports={
             "\\A": lhs_wire,
             "\\B": rhs_wire,
             "\\Y": res,
@@ -505,6 +515,18 @@ class _RHSValueCompiler(_ValueCompiler):
             "B_WIDTH": rhs_bits,
             "Y_WIDTH": res_bits,
         }, src=src(value.src_loc))
+        if value.operator in ("//", "%"):
+            # RTLIL leaves division by zero undefined, but we require it to return zero.
+            divmod_res = res
+            res = self.s.rtlil.wire(width=res_bits, src=src(value.src_loc))
+            self.s.rtlil.cell("$mux", ports={
+                "\\A": divmod_res,
+                "\\B": self(ast.Const(0, ast.Shape(res_bits, res_sign))),
+                "\\S": self(lhs == 0),
+                "\\Y": res,
+            }, params={
+                "WIDTH": res_bits
+            }, src=src(value.src_loc))
         return res
 
     def on_Operator_mux(self, value):
@@ -532,7 +554,7 @@ class _RHSValueCompiler(_ValueCompiler):
         elif len(value.operands) == 2:
             return self.on_Operator_binary(value)
         elif len(value.operands) == 3:
-            assert value.op == "m"
+            assert value.operator == "m"
             return self.on_Operator_mux(value)
         else:
             raise TypeError # :nocov:
@@ -554,7 +576,7 @@ class _RHSValueCompiler(_ValueCompiler):
         res_bits, res_sign = value.shape()
         res = self.s.rtlil.wire(width=res_bits, src=src(value.src_loc))
         # Note: Verilog's x[o+:w] construct produces a $shiftx cell, not a $shift cell.
-        # However, Migen's semantics defines the out-of-range bits to be zero, so it is correct
+        # However, nMigen's semantics defines the out-of-range bits to be zero, so it is correct
         # to use a $shift cell here instead, even though it produces less idiomatic Verilog.
         self.s.rtlil.cell("$shift", ports={
             "\\A": self(lhs),
@@ -610,9 +632,20 @@ class _LHSValueCompiler(_ValueCompiler):
     def on_Part(self, value):
         offset = self.s.expand(value.offset)
         if isinstance(offset, ast.Const):
-            return self(ast.Slice(value.value, offset.value, offset.value + value.width))
+            if offset.value == len(value.value):
+                dummy_wire = self.s.rtlil.wire(value.width)
+                return dummy_wire
+            return self(ast.Slice(value.value,
+                                  offset.value * value.stride,
+                                  offset.value * value.stride + value.width))
         else:
-            raise LegalizeValue(value.offset, range((1 << len(value.offset))), value.src_loc)
+            # Only so many possible parts. The amount of branches is exponential; if value.offset
+            # is large (e.g. 32-bit wide), trying to naively legalize it is likely to exhaust
+            # system resources.
+            max_branches = len(value.value) // value.stride + 1
+            raise LegalizeValue(value.offset,
+                                range(1 << len(value.offset))[:max_branches],
+                                value.src_loc)
 
     def on_Repl(self, value):
         raise TypeError # :nocov:
@@ -681,9 +714,17 @@ class _StatementCompiler(xfrm.StatementVisitor):
     def on_Switch(self, stmt):
         self._check_rhs(stmt.test)
 
-        if stmt not in self._test_cache:
-            self._test_cache[stmt] = self.rhs_compiler(stmt.test)
-        test_sigspec = self._test_cache[stmt]
+        if not self.state.expansions:
+            # We repeatedly translate the same switches over and over (see the LHSGroupAnalyzer
+            # related code below), and translating the switch test only once helps readability.
+            if stmt not in self._test_cache:
+                self._test_cache[stmt] = self.rhs_compiler(stmt.test)
+            test_sigspec = self._test_cache[stmt]
+        else:
+            # However, if the switch test contains an illegal value, then it may not be cached
+            # (since the illegal value will be repeatedly replaced with different constants), so
+            # don't cache anything in that case.
+            test_sigspec = self.rhs_compiler(stmt.test)
 
         with self._case.switch(test_sigspec, src=src(stmt.src_loc)) as switch:
             for values, stmts in stmt.cases.items():
@@ -709,13 +750,14 @@ class _StatementCompiler(xfrm.StatementVisitor):
         except LegalizeValue as legalize:
             with self._case.switch(self.rhs_compiler(legalize.value),
                                    src=src(legalize.src_loc)) as switch:
-                width, signed = legalize.value.shape()
-                tests = ["{:0{}b}".format(v, width) for v in legalize.branches]
-                tests[-1] = "-" * width
+                shape = legalize.value.shape()
+                tests = ["{:0{}b}".format(v, shape.width) for v in legalize.branches]
+                if tests:
+                    tests[-1] = "-" * shape.width
                 for branch, test in zip(legalize.branches, tests):
                     with self.case(switch, (test,)):
                         self._wrap_assign = False
-                        branch_value = ast.Const(branch, (width, signed))
+                        branch_value = ast.Const(branch, shape)
                         with self.state.expand_to(legalize.value, branch_value):
                             self.on_statement(stmt)
             self._wrap_assign = True
@@ -788,7 +830,7 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
                         memory = param_value
                         if memory not in memories:
                             memories[memory] = module.memory(width=memory.width, size=memory.depth,
-                                                             name=memory.name)
+                                                             name=memory.name, attrs=memory.attrs)
                             addr_bits = bits_for(memory.depth)
                             data_parts = []
                             data_mask = (1 << memory.width) - 1
@@ -867,6 +909,14 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
                     # by looking for any signals on RHS. If there aren't any, we add some logic
                     # whose only purpose is to trigger Verilog simulators when it converts
                     # through RTLIL and to Verilog, by populating the sensitivity list.
+                    #
+                    # Unfortunately, while this workaround allows true (event-driven) Verilog
+                    # simulators to work properly, and is universally ignored by synthesizers,
+                    # Verilator rejects it.
+                    #
+                    # Running the Yosys proc_prune pass converts such pathological `always @*`
+                    # blocks to `assign` statements, so this workaround can be removed completely
+                    # once support for Yosys 0.9 is dropped.
                     if not stmt_compiler._has_rhs:
                         if verilog_trigger is None:
                             verilog_trigger = \