genlib/fsm: add NextValue to replace reg/reg_next/ce pattern
authorSebastien Bourdeauducq <sb@m-labs.hk>
Tue, 25 Nov 2014 09:16:21 +0000 (17:16 +0800)
committerSebastien Bourdeauducq <sb@m-labs.hk>
Tue, 25 Nov 2014 09:16:21 +0000 (17:16 +0800)
examples/basic/fsm.py
migen/genlib/fsm.py

index 435aaa4b08dfc720a7b26d9db7ba2f8569c43159..769246758ba4021ada394e20acd6a2ba5ee7d23d 100644 (file)
@@ -1,18 +1,29 @@
 from migen.fhdl.std import *
 from migen.fhdl import verilog
-from migen.genlib.fsm import FSM, NextState
+from migen.genlib.fsm import FSM, NextState, NextValue
 
 class Example(Module):
        def __init__(self):
                self.s = Signal()
+               self.counter = Signal(8)
+
                myfsm = FSM()
                self.submodules += myfsm
-               myfsm.act("FOO", self.s.eq(1), NextState("BAR"))
-               myfsm.act("BAR", self.s.eq(0), NextState("FOO"))
+
+               myfsm.act("FOO",
+                       self.s.eq(1),
+                       NextState("BAR")
+               )
+               myfsm.act("BAR",
+                       self.s.eq(0),
+                       NextValue(self.counter, self.counter + 1),
+                       NextState("FOO")
+               )
+
                self.be = myfsm.before_entering("FOO")
                self.ae = myfsm.after_entering("FOO")
                self.bl = myfsm.before_leaving("FOO")
                self.al = myfsm.after_leaving("FOO")
 
 example = Example()
-print(verilog.convert(example, {example.s, example.be, example.ae, example.bl, example.al}))
+print(verilog.convert(example, {example.s, example.counter, example.be, example.ae, example.bl, example.al}))
index 7400262af0922366fcefb04e356d7626d82f2cf6..44a1be395cb2c612555f8fe8b062d041b9e2d6b0 100644 (file)
@@ -3,6 +3,7 @@ from collections import OrderedDict
 from migen.fhdl.std import *
 from migen.fhdl.module import FinalizeError
 from migen.fhdl.visit import NodeTransformer
+from migen.fhdl.bitcontainer import value_bits_sign
 
 class AnonymousState:
        pass
@@ -13,11 +14,18 @@ class NextState:
        def __init__(self, state):
                self.state = state
 
-class _LowerNextState(NodeTransformer):
+class NextValue:
+       def __init__(self, register, value):
+               self.register = register
+               self.value = value
+
+class _LowerNext(NodeTransformer):
        def __init__(self, next_state_signal, encoding, aliases):
                self.next_state_signal = next_state_signal
                self.encoding = encoding
                self.aliases = aliases
+               # register -> next_value_ce, next_value
+               self.registers = OrderedDict()
 
        def visit_unknown(self, node):
                if isinstance(node, NextState):
@@ -26,6 +34,15 @@ class _LowerNextState(NodeTransformer):
                        except KeyError:
                                actual_state = node.state
                        return self.next_state_signal.eq(self.encoding[actual_state])
+               elif isinstance(node, NextValue):
+                       try:
+                               next_value_ce, next_value = self.registers[node.register]
+                       except KeyError:
+                               related = node.register if isinstance(node.register, Signal) else None
+                               next_value = Signal(bits_sign=value_bits_sign(node.register), related=related)
+                               next_value_ce = Signal(related=related)
+                               self.registers[node.register] = next_value_ce, next_value
+                       return next_value.eq(node.value), next_value_ce.eq(1)
                else:
                        return node
 
@@ -97,18 +114,19 @@ class FSM(Module):
 
        def do_finalize(self):
                nstates = len(self.actions)
-
                self.encoding = dict((s, n) for n, s in enumerate(self.actions.keys()))
                self.state = Signal(max=nstates, reset=self.encoding[self.reset_state])
                self.next_state = Signal(max=nstates)
 
-               lns = _LowerNextState(self.next_state, self.encoding, self.state_aliases)
-               cases = dict((self.encoding[k], lns.visit(v)) for k, v in self.actions.items() if v)
+               ln = _LowerNext(self.next_state, self.encoding, self.state_aliases)
+               cases = dict((self.encoding[k], ln.visit(v)) for k, v in self.actions.items() if v)
                self.comb += [
                        self.next_state.eq(self.state),
                        Case(self.state, cases).makedefault(self.encoding[self.reset_state])
                ]
                self.sync += self.state.eq(self.next_state)
+               for register, (next_value_ce, next_value) in ln.registers.items():
+                       self.sync += If(next_value_ce, register.eq(next_value))
 
                # drive entering/leaving signals
                for state, signal in self.before_leaving_signals.items():