compat.genlib.fsm: import/wrap Migen code.
authorwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 12:40:14 +0000 (12:40 +0000)
committerwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 12:41:19 +0000 (12:41 +0000)
nmigen/back/verilog.py
nmigen/compat/__init__.py
nmigen/compat/fhdl/structure.py
nmigen/compat/genlib/fsm.py [new file with mode: 0644]
nmigen/fhdl/xfrm.py

index 8ac5f1fff49195f9fab88964b953a7c89f4a5d38..3eaf07cc1b3c6e42f57f0614815765c364331779 100644 (file)
@@ -30,7 +30,8 @@ proc_clean
 write_verilog
 # Make sure there are no undriven wires in generated RTLIL.
 proc
-select -assert-none w:* i:* %a %d c:* %co* %a %d n:$* %d
+write_ilang x.il
+select -assert-none w:* i:* %a %d o:* %a %ci* %d c:* %co* %a %d n:$* %d
 """.format(il_text))
     if popen.returncode:
         raise YosysError(error.strip())
index 25bd84ac81ebc87a953c640412b2e490d32200a1..8617cce5af82342dd0a1a7a1e66b195ce8e09501 100644 (file)
@@ -8,4 +8,4 @@ from .fhdl.bitcontainer import *
 # from .sim import *
 
 # from .genlib.record import *
-from .genlib.fsm import *
+from .genlib.fsm import *
index 95b8b5ca65d6a7b6662638920d17ae9daea2bdbc..4f89080422b5e305ce86b482562e2e03b6312a3f 100644 (file)
@@ -45,17 +45,41 @@ class If(ast.Switch):
 
 
 class Case(ast.Switch):
-    @deprecated("instead of `Case(test, ...)`, use `with m.Case(test, ...):`")
+    @deprecated("instead of `Case(test, { value: stmts })`, use `with m.Switch(test):` and "
+                "`with m.Case(value): stmts`; instead of `\"default\": stmts`, use "
+                "`with m.Case(): stmts`")
     def __init__(self, test, cases):
         new_cases = []
         for k, v in cases.items():
-            if k == "default":
+            if isinstance(k, (bool, int)):
+                k = Const(k)
+            if (not isinstance(k, Const)
+                    and not (isinstance(k, str) and k == "default")):
+                raise TypeError("Case object is not a Migen constant")
+            if isinstance(k, str) and k == "default":
                 k = "-" * len(ast.Value.wrap(test))
+            else:
+                k = k.value
             new_cases.append((k, v))
         super().__init__(test, OrderedDict(new_cases))
 
+    @deprecated("instead of `Case(...).makedefault()`, use an explicit default case: "
+                "`with m.Case(): ...`")
     def makedefault(self, key=None):
-        raise NotImplementedError
+        if key is None:
+            for choice in self.cases.keys():
+                if (key is None
+                        or (isinstance(choice, str) and choice == "default")
+                        or choice > key):
+                    key = choice
+        if isinstance(key, str) and key == "default":
+            key = "-" * len(self.test)
+        else:
+            key = "{:0{}b}".format(wrap(key).value, len(self.test))
+        stmts = self.cases[key]
+        del self.cases[key]
+        self.cases["-" * len(self.test)] = stmts
+        return self
 
 
 def Array(*args):
diff --git a/nmigen/compat/genlib/fsm.py b/nmigen/compat/genlib/fsm.py
new file mode 100644 (file)
index 0000000..548eb6a
--- /dev/null
@@ -0,0 +1,187 @@
+import warnings
+from collections import OrderedDict
+
+from ...fhdl.xfrm import ValueTransformer, StatementTransformer
+from ...fhdl.ast import *
+from ..fhdl.module import CompatModule, CompatFinalizeError
+from ..fhdl.structure import If, Case
+
+
+__all__ = ["AnonymousState", "NextState", "NextValue", "FSM"]
+
+
+class AnonymousState:
+    pass
+
+
+class NextState(Statement):
+    def __init__(self, state):
+        self.state = state
+
+
+class NextValue(Statement):
+    def __init__(self, target, value):
+        self.target = target
+        self.value = value
+
+
+def _target_eq(a, b):
+    if type(a) != type(b):
+        return False
+    ty = type(a)
+    if ty == Const:
+        return a.value == b.value
+    elif ty == Signal:
+        return a is b
+    elif ty == Cat:
+        return all(_target_eq(x, y) for x, y in zip(a.l, b.l))
+    elif ty == Slice:
+        return (_target_eq(a.value, b.value)
+                    and a.start == b.start
+                    and a.stop == b.stop)
+    elif ty == Part:
+        return (_target_eq(a.value, b.value)
+                    and _target_eq(a.offset == b.offset)
+                    and a.width == b.width)
+    elif ty == ArrayProxy:
+        return (all(_target_eq(x, y) for x, y in zip(a.choices, b.choices))
+                    and _target_eq(a.key, b.key))
+    else:
+        raise ValueError("NextValue cannot be used with target type '{}'"
+                         .format(ty))
+
+
+class _LowerNext(ValueTransformer, StatementTransformer):
+    def __init__(self, next_state_signal, encoding, aliases):
+        self.next_state_signal = next_state_signal
+        self.encoding = encoding
+        self.aliases = aliases
+        # (target, next_value_ce, next_value)
+        self.registers = []
+
+    def _get_register_control(self, target):
+        for x in self.registers:
+            if _target_eq(target, x[0]):
+                return x[1], x[2]
+        raise KeyError
+
+    def on_unknown_statement(self, node):
+        if isinstance(node, NextState):
+            try:
+                actual_state = self.aliases[node.state]
+            except KeyError:
+                actual_state = node.state
+            return self.next_state_signal.eq(self.encoding[actual_state])
+        elif isinstance(node, NextValue):
+            try:
+                next_value_ce, next_value = self._get_register_control(node.target)
+            except KeyError:
+                related = node.target if isinstance(node.target, Signal) else None
+                next_value = Signal(node.target.shape())
+                next_value_ce = Signal()
+                self.registers.append((node.target, next_value_ce, next_value))
+            return next_value.eq(node.value), next_value_ce.eq(1)
+        else:
+            return node
+
+
+class FSM(CompatModule):
+    def __init__(self, reset_state=None):
+        self.actions = OrderedDict()
+        self.state_aliases = dict()
+        self.reset_state = reset_state
+
+        self.before_entering_signals = OrderedDict()
+        self.before_leaving_signals = OrderedDict()
+        self.after_entering_signals = OrderedDict()
+        self.after_leaving_signals = OrderedDict()
+
+    def act(self, state, *statements):
+        if self.finalized:
+            raise CompatFinalizeError
+        if self.reset_state is None:
+            self.reset_state = state
+        if state not in self.actions:
+            self.actions[state] = []
+        self.actions[state] += statements
+
+    def delayed_enter(self, name, target, delay):
+        if self.finalized:
+            raise CompatFinalizeError
+        if delay > 0:
+            state = name
+            for i in range(delay):
+                if i == delay - 1:
+                    next_state = target
+                else:
+                    next_state = AnonymousState()
+                self.act(state, NextState(next_state))
+                state = next_state
+        else:
+            self.state_aliases[name] = target
+
+    def ongoing(self, state):
+        is_ongoing = Signal()
+        self.act(state, is_ongoing.eq(1))
+        return is_ongoing
+
+    def _get_signal(self, d, state):
+        if state not in self.actions:
+            self.actions[state] = []
+        try:
+            return d[state]
+        except KeyError:
+            is_el = Signal()
+            d[state] = is_el
+            return is_el
+
+    def before_entering(self, state):
+        return self._get_signal(self.before_entering_signals, state)
+
+    def before_leaving(self, state):
+        return self._get_signal(self.before_leaving_signals, state)
+
+    def after_entering(self, state):
+        signal = self._get_signal(self.after_entering_signals, state)
+        self.sync += signal.eq(self.before_entering(state))
+        return signal
+
+    def after_leaving(self, state):
+        signal = self._get_signal(self.after_leaving_signals, state)
+        self.sync += signal.eq(self.before_leaving(state))
+        return signal
+
+    def do_finalize(self):
+        nstates = len(self.actions)
+        self.encoding = dict((s, n) for n, s in enumerate(self.actions.keys()))
+        self.decoding = {n: s for s, n in self.encoding.items()}
+
+        self.state = Signal(max=nstates, reset=self.encoding[self.reset_state])
+        self.state._enumeration = self.decoding
+        self.next_state = Signal(max=nstates)
+        self.next_state._enumeration = {n: "{}:{}".format(n, s) for n, s in self.decoding.items()}
+
+        for state, signal in self.before_leaving_signals.items():
+            encoded = self.encoding[state]
+            self.comb += signal.eq((self.state == encoded) & ~(self.next_state == encoded))
+        if self.reset_state in self.after_entering_signals:
+            self.after_entering_signals[self.reset_state].reset = 1
+        for state, signal in self.before_entering_signals.items():
+            encoded = self.encoding[state]
+            self.comb += signal.eq(~(self.state == encoded) & (self.next_state == encoded))
+
+        self._finalize_sync(self._lower_controls())
+
+    def _lower_controls(self):
+        return _LowerNext(self.next_state, self.encoding, self.state_aliases)
+
+    def _finalize_sync(self, ls):
+        cases = dict((self.encoding[k], ls.on_statement(v)) for k, v in self.actions.items() if v)
+        with warnings.catch_warnings():
+            self.comb += [
+                self.next_state.eq(self.state),
+                Case(self.state, cases).makedefault(self.encoding[self.reset_state])
+            ]
+            self.sync += self.state.eq(self.next_state)
+            for register, next_value_ce, next_value in ls.registers:
+                self.sync += If(next_value_ce, register.eq(next_value))
index a2befa4fa9c4f251eb4bff4616c0b70bf6f0cf67..2c74e433cff63226fbb12d34e29c04c96302f058 100644 (file)
@@ -1,5 +1,6 @@
-from collections import OrderedDict
+from collections import OrderedDict, Iterable
 
+from ..tools import flatten
 from .ast import *
 from .ir import *
 
@@ -36,6 +37,9 @@ class ValueTransformer:
     def on_Repl(self, value):
         return Repl(self.on_value(value.value), value.count)
 
+    def on_unknown_value(self, value):
+        raise TypeError("Cannot transform value {!r}".format(value)) # :nocov:
+
     def on_value(self, value):
         if isinstance(value, Const):
             new_value = self.on_Const(value)
@@ -56,7 +60,7 @@ class ValueTransformer:
         elif isinstance(value, Repl):
             new_value = self.on_Repl(value)
         else:
-            raise TypeError("Cannot transform value {!r}".format(value)) # :nocov:
+            new_value = self.on_unknown_value(value)
         if isinstance(new_value, Value):
             new_value.src_loc = value.src_loc
         return new_value
@@ -73,21 +77,24 @@ class StatementTransformer:
         return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
 
     def on_Switch(self, stmt):
-        cases = OrderedDict((k, self.on_value(v)) for k, v in stmt.cases.items())
+        cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
         return Switch(self.on_value(stmt.test), cases)
 
     def on_statements(self, stmt):
-        return list(flatten(self.on_statement(stmt) for stmt in self.on_statement(stmt)))
+        return list(flatten(self.on_statement(stmt) for stmt in stmt))
+
+    def on_unknown_statement(self, stmt):
+        raise TypeError("Cannot transform statement {!r}".format(stmt)) # :nocov:
 
     def on_statement(self, stmt):
         if isinstance(stmt, Assign):
             return self.on_Assign(stmt)
         elif isinstance(stmt, Switch):
             return self.on_Switch(stmt)
-        elif isinstance(stmt, (list, tuple)):
+        elif isinstance(stmt, Iterable):
             return self.on_statements(stmt)
         else:
-            raise TypeError("Cannot transform statement {!r}".format(stmt)) # :nocov:
+            return self.on_unknown_statement(stmt)
 
     def __call__(self, value):
         return self.on_statement(value)