back.rtlil: implement Part.
authorwhitequark <whitequark@whitequark.org>
Mon, 17 Dec 2018 01:05:08 +0000 (01:05 +0000)
committerwhitequark <whitequark@whitequark.org>
Mon, 17 Dec 2018 01:05:08 +0000 (01:05 +0000)
doc/COMPAT_SUMMARY.md
nmigen/back/rtlil.py
nmigen/hdl/ast.py

index c2180860be9998f522a136821c8216b2b9b75dc6..70f15454ba4748c92bd92890965d06eef01bcf29 100644 (file)
@@ -63,7 +63,7 @@ Compatibility summary
         <br>Note: values no longer valid as keys in `dict` and `set`; use `ValueDict` and `ValueSet` instead.
       - (+) `wrap` → `Value.wrap`
       - (+) `_Operator` → `Operator`
-      - (+) `Mux` → `Mux`
+      - (+) `Mux` id
       - (+) `_Slice` → `Slice`, `stop=`→`end=`, `.stop`→`.end`
       - (+) `_Part` → `Part`
       - (+) `Cat` id, `.l`→`.operands`
index 67ffdd3e8491e5bfe467a7c5b2dd670a521f5c68..3965511702a0318ead64b2e3d2956ff67df9be1f 100644 (file)
@@ -204,12 +204,20 @@ def src(src_loc):
     return "{}:{}".format(file, line)
 
 
+class LegalizeValue(Exception):
+    def __init__(self, value, branches):
+        self.value    = value
+        self.branches = list(branches)
+
+
 class _ValueCompilerState:
     def __init__(self, rtlil):
-        self.rtlil    = rtlil
-        self.wires    = ast.ValueDict()
-        self.driven   = ast.ValueDict()
-        self.ports    = ast.ValueDict()
+        self.rtlil  = rtlil
+        self.wires  = ast.ValueDict()
+        self.driven = ast.ValueDict()
+        self.ports  = ast.ValueDict()
+
+        self.expansions = ast.ValueDict()
 
     def add_driven(self, signal, sync):
         self.driven[signal] = sync
@@ -255,11 +263,26 @@ class _ValueCompilerState:
         wire_curr, wire_next = self.resolve(signal, prefix)
         return wire_curr
 
+    def expand(self, value):
+        return self.expansions.get(value, value)
+
+    @contextmanager
+    def expand_to(self, value, expansion):
+        try:
+            assert value not in self.expansions
+            self.expansions[value] = expansion
+            yield
+        finally:
+            del self.expansions[value]
+
 
 class _ValueCompiler(xfrm.AbstractValueTransformer):
     def __init__(self, state):
         self.s = state
 
+    def on_value(self, value):
+        return super().on_value(self.s.expand(value))
+
     def on_unknown(self, value):
         if value is None:
             return None
@@ -426,7 +449,26 @@ class _RHSValueCompiler(_ValueCompiler):
         return sigspec
 
     def on_Part(self, value):
-        raise NotImplementedError
+        lhs, rhs = value.value, value.offset
+        lhs_bits, lhs_sign = lhs.shape()
+        rhs_bits, rhs_sign = rhs.shape()
+        res_bits, res_sign = value.shape()
+        res = self.s.rtlil.wire(width=res_bits)
+        # Note: Verilog's x[o+:w] construct produces a $shiftx cell, not a $shift cell.
+        # However, Migen's semantics defines the out-of-range bits to be zero, so it is correct
+        # to use a $shift cell here instead, even though it produces less idiomatic Verilog.
+        self.s.rtlil.cell("$shift", ports={
+            "\\A": self(lhs),
+            "\\B": self(rhs),
+            "\\Y": res,
+        }, params={
+            "A_SIGNED": lhs_sign,
+            "A_WIDTH": lhs_bits,
+            "B_SIGNED": rhs_sign,
+            "B_WIDTH": rhs_bits,
+            "Y_WIDTH": res_bits,
+        }, src=src(value.src_loc))
+        return res
 
     def on_Repl(self, value):
         return "{{ {} }}".format(" ".join(self(value.value) for _ in range(value.count)))
@@ -453,7 +495,11 @@ class _LHSValueCompiler(_ValueCompiler):
         return self(value)
 
     def on_Part(self, value):
-        raise NotImplementedError
+        offset = self.s.expand(value.offset)
+        if isinstance(offset, ast.Const):
+            return self(ast.Slice(value.value, offset.value, offset.value + value.width))
+        else:
+            raise LegalizeValue(value.offset, range(0, (1 << len(value.offset) - 1)))
 
     def on_Repl(self, value):
         raise TypeError # :nocov:
@@ -463,10 +509,22 @@ class _LHSValueCompiler(_ValueCompiler):
 
 
 class _StatementCompiler(xfrm.AbstractStatementTransformer):
-    def __init__(self, rhs_compiler, lhs_compiler):
+    def __init__(self, state, rhs_compiler, lhs_compiler):
+        self.state        = state
         self.rhs_compiler = rhs_compiler
         self.lhs_compiler = lhs_compiler
 
+        self._case = None
+
+    @contextmanager
+    def case(self, switch, value):
+        try:
+            old_case = self._case
+            with switch.case(value) as self._case:
+                yield
+        finally:
+            self._case = old_case
+
     def on_Assign(self, stmt):
         if isinstance(stmt, ast.Assign):
             lhs_bits, lhs_sign = stmt.lhs.shape()
@@ -477,15 +535,27 @@ class _StatementCompiler(xfrm.AbstractStatementTransformer):
                 # In RTLIL, LHS and RHS of assignment must have exactly same width.
                 rhs_sigspec = self.rhs_compiler.match_shape(
                     stmt.rhs, lhs_bits, rhs_sign)
-            self.case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec)
+            self._case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec)
 
     def on_Switch(self, stmt):
-        with self.case.switch(self.rhs_compiler(stmt.test)) as switch:
+        with self._case.switch(self.rhs_compiler(stmt.test)) as switch:
             for value, stmts in stmt.cases.items():
-                old_case = self.case
-                with switch.case(value) as self.case:
+                with self.case(switch, value):
                     self.on_statements(stmts)
-                self.case = old_case
+
+    def on_statement(self, stmt):
+        try:
+            super().on_statement(stmt)
+        except LegalizeValue as legalize:
+            with self._case.switch(self.rhs_compiler(legalize.value)) as switch:
+                bits, sign = legalize.value.shape()
+                tests = ["{:0{}b}".format(v, bits) for v in legalize.branches]
+                tests[-1] = "-" * bits
+                for branch, test in zip(legalize.branches, tests):
+                    with self.case(switch, test):
+                        branch_value = ast.Const(branch, (bits, sign))
+                        with self.state.expand_to(legalize.value, branch_value):
+                            super().on_statement(stmt)
 
     def on_statements(self, stmts):
         for stmt in stmts:
@@ -497,7 +567,7 @@ def convert_fragment(builder, fragment, name, top):
         compiler_state = _ValueCompilerState(module)
         rhs_compiler   = _RHSValueCompiler(compiler_state)
         lhs_compiler   = _LHSValueCompiler(compiler_state)
-        stmt_compiler  = _StatementCompiler(rhs_compiler, lhs_compiler)
+        stmt_compiler  = _StatementCompiler(compiler_state, rhs_compiler, lhs_compiler)
 
         # Register all signals driven in the current fragment. This must be done first, as it
         # affects further codegen; e.g. whether sig$next signals will be generated and used.
@@ -541,7 +611,7 @@ def convert_fragment(builder, fragment, name, top):
                     case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
 
                 # Convert statements into decision trees.
-                stmt_compiler.case = case
+                stmt_compiler._case = case
                 stmt_compiler(fragment.statements)
 
             # For every signal in the sync domain, assign \sig's initial value (which will end up
index a2d3adf705c6c42208a631efb1ac97132c0512bf..049a68cc84f9b4b7da4732ca092945bea8036939 100644 (file)
@@ -890,13 +890,21 @@ class ValueKey:
 
     def __hash__(self):
         if isinstance(self.value, Const):
-            return hash(self.value)
+            return hash(self.value.value)
         elif isinstance(self.value, Signal):
             return hash(id(self.value))
+        elif isinstance(self.value, Operator):
+            return hash((self.value.op, tuple(ValueKey(o) for o in self.value.operands)))
         elif isinstance(self.value, Slice):
             return hash((ValueKey(self.value.value), self.value.start, self.value.end))
+        elif isinstance(self.value, Part):
+            return hash((ValueKey(self.value.value), ValueKey(self.value.offset),
+                         self.value.width))
+        elif isinstance(self.value, Cat):
+            return hash(tuple(ValueKey(o) for o in self.value.operands))
         else: # :nocov:
-            raise TypeError("Object '{!r}' cannot be used as a key in value collections")
+            raise TypeError("Object '{!r}' cannot be used as a key in value collections"
+                            .format(self.value))
 
     def __eq__(self, other):
         if not isinstance(other, ValueKey):
@@ -905,15 +913,28 @@ class ValueKey:
             return False
 
         if isinstance(self.value, Const):
-            return self.value == other.value
+            return self.value.value == other.value.value
         elif isinstance(self.value, Signal):
             return id(self.value) == id(other.value)
+        elif isinstance(self.value, Operator):
+            return (self.value.op == other.value.op and
+                    len(self.value.operands) == len(other.value.operands) and
+                    all(ValueKey(a) == ValueKey(b)
+                        for a, b in zip(self.value.operands, other.value.operands)))
         elif isinstance(self.value, Slice):
             return (ValueKey(self.value.value) == ValueKey(other.value.value) and
                     self.value.start == other.value.start and
                     self.value.end == other.value.end)
+        elif isinstance(self.value, Part):
+            return (ValueKey(self.value.value) == ValueKey(other.value.value) and
+                    ValueKey(self.value.offset) == ValueKey(other.value.offset) and
+                    self.value.width == other.value.width)
+        elif isinstance(self.value, Cat):
+            return all(ValueKey(a) == ValueKey(b)
+                        for a, b in zip(self.value.operands, other.value.operands))
         else: # :nocov:
-            raise TypeError("Object '{!r}' cannot be used as a key in value collections")
+            raise TypeError("Object '{!r}' cannot be used as a key in value collections"
+                            .format(self.value))
 
     def __lt__(self, other):
         if not isinstance(other, ValueKey):