FSM: new API
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Tue, 25 Jun 2013 20:17:39 +0000 (22:17 +0200)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Tue, 25 Jun 2013 20:17:39 +0000 (22:17 +0200)
examples/basic/fsm.py
migen/actorlib/misc.py
migen/bus/wishbone2asmi.py
migen/bus/wishbone2lasmi.py
migen/genlib/fsm.py
migen/pytholite/compiler.py
migen/pytholite/fsm.py
migen/pytholite/io.py

index eb7e03aee7a49b8110c39e5f4adc63edfcf7b407..fc8a64240085eca73032ade26d0c16f2dea53215 100644 (file)
@@ -1,14 +1,14 @@
 from migen.fhdl.std import *
 from migen.fhdl import verilog
-from migen.genlib.fsm import FSM
+from migen.genlib.fsm import FSM, NextState
 
 class Example(Module):
        def __init__(self):
                self.s = Signal()
-               myfsm = FSM("FOO", "BAR")
+               myfsm = FSM()
                self.submodules += myfsm
-               myfsm.act(myfsm.FOO, self.s.eq(1), myfsm.next_state(myfsm.BAR))
-               myfsm.act(myfsm.BAR, self.s.eq(0), myfsm.next_state(myfsm.FOO))
+               myfsm.act("FOO", self.s.eq(1), NextState("BAR"))
+               myfsm.act("BAR", self.s.eq(0), NextState("FOO"))
 
 example = Example()
 print(verilog.convert(example, {example.s}))
index b99ad2e589094551bb9df0ca2f4c937e812f4907..23e405cd75c94c3c409d6c6ed02672247587d54e 100644 (file)
@@ -47,18 +47,18 @@ class IntSequence(Module):
                else:
                        self.comb += self.source.payload.value.eq(counter)
                
-               fsm = FSM("IDLE", "ACTIVE")
+               fsm = FSM()
                self.submodules += fsm
-               fsm.act(fsm.IDLE,
+               fsm.act("IDLE",
                        load.eq(1),
                        self.parameters.ack.eq(1),
-                       If(self.parameters.stb, fsm.next_state(fsm.ACTIVE))
+                       If(self.parameters.stb, NextState("ACTIVE"))
                )
-               fsm.act(fsm.ACTIVE,
+               fsm.act("ACTIVE",
                        self.busy.eq(1),
                        self.source.stb.eq(1),
                        If(self.source.ack,
                                ce.eq(1),
-                               If(last, fsm.next_state(fsm.IDLE))
+                               If(last, NextState("IDLE"))
                        )
                )
index 2c9d533d3921fbe6a17425e8c6d0872b888f4566..49902a8c10c61c5e58b8191af646a37fa67750da 100644 (file)
@@ -1,6 +1,6 @@
 from migen.fhdl.std import *
 from migen.bus import wishbone
-from migen.genlib.fsm import FSM
+from migen.genlib.fsm import FSM, NextState
 from migen.genlib.misc import split, displacer, chooser
 from migen.genlib.record import Record, layout_len
 
@@ -79,60 +79,58 @@ class WB2ASMI:
                write_to_asmi_pre = Signal()
                sync.append(write_to_asmi.eq(write_to_asmi_pre))
                
-               fsm = FSM("IDLE", "TEST_HIT",
-                       "EVICT_ISSUE", "EVICT_WAIT",
-                       "REFILL_WRTAG", "REFILL_ISSUE", "REFILL_WAIT", "REFILL_COMPLETE")
+               fsm = FSM()
                
-               fsm.act(fsm.IDLE,
-                       If(self.wishbone.cyc & self.wishbone.stb, fsm.next_state(fsm.TEST_HIT))
+               fsm.act("IDLE",
+                       If(self.wishbone.cyc & self.wishbone.stb, NextState("TEST_HIT"))
                )
-               fsm.act(fsm.TEST_HIT,
+               fsm.act("TEST_HIT",
                        If(tag_do.tag == adr_tag,
                                self.wishbone.ack.eq(1),
                                If(self.wishbone.we,
                                        tag_di.dirty.eq(1),
                                        tag_port.we.eq(1)
                                ),
-                               fsm.next_state(fsm.IDLE)
+                               NextState("IDLE")
                        ).Else(
                                If(tag_do.dirty,
-                                       fsm.next_state(fsm.EVICT_ISSUE)
+                                       NextState("EVICT_ISSUE")
                                ).Else(
-                                       fsm.next_state(fsm.REFILL_WRTAG)
+                                       NextState("REFILL_WRTAG")
                                )
                        )
                )
                
-               fsm.act(fsm.EVICT_ISSUE,
+               fsm.act("EVICT_ISSUE",
                        self.asmiport.stb.eq(1),
                        self.asmiport.we.eq(1),
-                       If(self.asmiport.ack, fsm.next_state(fsm.EVICT_WAIT))
+                       If(self.asmiport.ack, NextState("EVICT_WAIT"))
                )
-               fsm.act(fsm.EVICT_WAIT,
+               fsm.act("EVICT_WAIT",
                        # Data is actually sampled by the memory controller in the next state.
                        # But since the data memory has one cycle latency, it gets the data
                        # at the address given during this cycle.
                        If(self.asmiport.get_call_expression(),
                                write_to_asmi_pre.eq(1),
-                               fsm.next_state(fsm.REFILL_WRTAG)
+                               NextState("REFILL_WRTAG")
                        )
                )
                
-               fsm.act(fsm.REFILL_WRTAG,
+               fsm.act("REFILL_WRTAG",
                        # Write the tag first to set the ASMI address
                        tag_port.we.eq(1),
-                       fsm.next_state(fsm.REFILL_ISSUE)
+                       NextState("REFILL_ISSUE")
                )
-               fsm.act(fsm.REFILL_ISSUE,
+               fsm.act("REFILL_ISSUE",
                        self.asmiport.stb.eq(1),
-                       If(self.asmiport.ack, fsm.next_state(fsm.REFILL_WAIT))
+                       If(self.asmiport.ack, NextState("REFILL_WAIT"))
                )
-               fsm.act(fsm.REFILL_WAIT,
-                       If(self.asmiport.get_call_expression(), fsm.next_state(fsm.REFILL_COMPLETE))
+               fsm.act("REFILL_WAIT",
+                       If(self.asmiport.get_call_expression(), NextState("REFILL_COMPLETE"))
                )
-               fsm.act(fsm.REFILL_COMPLETE,
+               fsm.act("REFILL_COMPLETE",
                        write_from_asmi.eq(1),
-                       fsm.next_state(fsm.TEST_HIT)
+                       NextState("TEST_HIT")
                )
                
                return Fragment(comb, sync, specials={data_mem, tag_mem, data_port, tag_port}) \
index c9a538661a8c791137b4433fa6a75120dd0cd1d1..4c7881cdd7e6c07c5b2efc493f79fa70e0274838 100644 (file)
@@ -1,6 +1,6 @@
 from migen.fhdl.std import *
 from migen.bus import wishbone
-from migen.genlib.fsm import FSM
+from migen.genlib.fsm import FSM, NextState
 from migen.genlib.misc import split, displacer, chooser
 from migen.genlib.record import Record, layout_len
 
@@ -71,61 +71,58 @@ class WB2LASMI(Module):
                
                # Control FSM
                assert(lasmim.write_latency >= 1 and lasmim.read_latency >= 1)
-               fsm = FSM("IDLE", "TEST_HIT",
-                       "EVICT_REQUEST", "EVICT_WAIT_DATA_ACK", "EVICT_DATA",
-                       "REFILL_WRTAG", "REFILL_REQUEST", "REFILL_WAIT_DATA_ACK", "REFILL_DATA",
-                       delayed_enters=[
-                               ("EVICT_DATAD", "EVICT_DATA", lasmim.write_latency-1),
-                               ("REFILL_DATAD", "REFILL_DATA", lasmim.read_latency-1)
-                       ])
+               fsm = FSM()
                self.submodules += fsm
                
-               fsm.act(fsm.IDLE,
-                       If(self.wishbone.cyc & self.wishbone.stb, fsm.next_state(fsm.TEST_HIT))
+               fsm.delayed_enter("EVICT_DATAD", "EVICT_DATA", lasmim.write_latency-1)
+               fsm.delayed_enter("REFILL_DATAD", "REFILL_DATA", lasmim.read_latency-1)
+
+               fsm.act("IDLE",
+                       If(self.wishbone.cyc & self.wishbone.stb, NextState("TEST_HIT"))
                )
-               fsm.act(fsm.TEST_HIT,
+               fsm.act("TEST_HIT",
                        If(tag_do.tag == adr_tag,
                                self.wishbone.ack.eq(1),
                                If(self.wishbone.we,
                                        tag_di.dirty.eq(1),
                                        tag_port.we.eq(1)
                                ),
-                               fsm.next_state(fsm.IDLE)
+                               NextState("IDLE")
                        ).Else(
                                If(tag_do.dirty,
-                                       fsm.next_state(fsm.EVICT_REQUEST)
+                                       NextState("EVICT_REQUEST")
                                ).Else(
-                                       fsm.next_state(fsm.REFILL_WRTAG)
+                                       NextState("REFILL_WRTAG")
                                )
                        )
                )
                
-               fsm.act(fsm.EVICT_REQUEST,
+               fsm.act("EVICT_REQUEST",
                        lasmim.stb.eq(1),
                        lasmim.we.eq(1),
-                       If(lasmim.req_ack, fsm.next_state(fsm.EVICT_WAIT_DATA_ACK))
+                       If(lasmim.req_ack, NextState("EVICT_WAIT_DATA_ACK"))
                )
-               fsm.act(fsm.EVICT_WAIT_DATA_ACK,
-                       If(lasmim.dat_ack, fsm.next_state(fsm.EVICT_DATAD))
+               fsm.act("EVICT_WAIT_DATA_ACK",
+                       If(lasmim.dat_ack, NextState("EVICT_DATAD"))
                )
-               fsm.act(fsm.EVICT_DATA,
+               fsm.act("EVICT_DATA",
                        write_to_lasmi.eq(1),
-                       fsm.next_state(fsm.REFILL_WRTAG)
+                       NextState("REFILL_WRTAG")
                )
                
-               fsm.act(fsm.REFILL_WRTAG,
+               fsm.act("REFILL_WRTAG",
                        # Write the tag first to set the LASMI address
                        tag_port.we.eq(1),
-                       fsm.next_state(fsm.REFILL_REQUEST)
+                       NextState("REFILL_REQUEST")
                )
-               fsm.act(fsm.REFILL_REQUEST,
+               fsm.act("REFILL_REQUEST",
                        lasmim.stb.eq(1),
-                       If(lasmim.req_ack, fsm.next_state(fsm.REFILL_WAIT_DATA_ACK))
+                       If(lasmim.req_ack, NextState("REFILL_WAIT_DATA_ACK"))
                )
-               fsm.act(fsm.REFILL_WAIT_DATA_ACK,
-                       If(lasmim.dat_ack, fsm.next_state(fsm.REFILL_DATAD))
+               fsm.act("REFILL_WAIT_DATA_ACK",
+                       If(lasmim.dat_ack, NextState("REFILL_DATAD"))
                )
-               fsm.act(fsm.REFILL_DATA,
+               fsm.act("REFILL_DATA",
                        write_from_lasmi.eq(1),
-                       fsm.next_state(fsm.TEST_HIT)
+                       NextState("TEST_HIT")
                )
index 9463f002eef32506e74ec700da1f77852b878a28..7765f1521bfd6fa6cc6f9c663773315a0b0c7441 100644 (file)
@@ -1,41 +1,73 @@
+from collections import OrderedDict
+
 from migen.fhdl.std import *
+from migen.fhdl.module import FinalizeError
+from migen.fhdl.visit import NodeTransformer
 
-class FSM:
-       def __init__(self, *states, delayed_enters=[]):
-               nstates = len(states) + sum([d[2] for d in delayed_enters])
-               
-               self._state = Signal(max=nstates)
-               self._next_state = Signal(max=nstates)
-               for n, state in enumerate(states):
-                       setattr(self, state, n)
-               self.actions = [[] for i in range(len(states))]
+class AnonymousState:
+       pass
+
+# do not use namedtuple here as it inherits tuple
+# and the latter is used elsewhere in FHDL
+class NextState:
+       def __init__(self, state):
+               self.state = state
+
+class _LowerNextState(NodeTransformer):
+       def __init__(self, next_state_signal, encoding, aliases):
+               self.next_state_signal = next_state_signal
+               self.encoding = encoding
+               self.aliases = aliases
                
-               for name, target, delay in delayed_enters:
-                       target_state = getattr(self, target)
-                       if delay:
-                               name_state = len(self.actions)
-                               setattr(self, name, name_state)
-                               for i in range(delay-1):
-                                       self.actions.append([self.next_state(name_state+i+1)])
-                               self.actions.append([self.next_state(target_state)])
-                       else:
-                               # alias
-                               setattr(self, name, target_state)
-       
-       def reset_state(self, state):
-               self._state.reset = state
-       
-       def next_state(self, state):
-               return self._next_state.eq(state)
-       
+       def visit_unknown(self, node):
+               if isinstance(node, NextState):
+                       try:
+                               actual_state = self.aliases[node.state]
+                       except KeyError:
+                               actual_state = node.state
+                       return self.next_state_signal.eq(self.encoding[actual_state])
+               else:
+                       return node
+
+class FSM(Module):
+       def __init__(self):
+               self.actions = OrderedDict()
+               self.state_aliases = dict()
+               self.reset_state = None
+
        def act(self, state, *statements):
+               if self.finalized:
+                       raise FinalizeError
+               if state not in self.actions:
+                       self.actions[state] = []
                self.actions[state] += statements
+
+       def delayed_enter(self, name, target, delay):
+               if self.finalized:
+                       raise FinalizeError
+               if delay:
+                       state = name
+                       for i in range(delay):
+                               if i == delay - 1:
+                                       next_state = target
+                               else:
+                                       next_state = AnonymousState()
+                               self.act(state, NextState(next_state))
+                               state = next_state
+               else:
+                       self.state_aliases[name] = target
        
-       def get_fragment(self):
-               cases = dict((s, a) for s, a in enumerate(self.actions) if a)
-               comb = [
-                       self._next_state.eq(self._state),
-                       Case(self._state, cases)
+       def do_finalize(self):
+               nstates = len(self.actions)
+
+               self.encoding = dict((s, n) for n, s in enumerate(self.actions.keys()))
+               self.state = Signal(max=nstates)
+               self.next_state = Signal(max=nstates)
+
+               lns = _LowerNextState(self.next_state, self.encoding, self.state_aliases)
+               cases = dict((self.encoding[k], lns.visit(v)) for k, v in self.actions.items() if v)
+               self.comb += [
+                       self.next_state.eq(self.state),
+                       Case(self.state, cases)
                ]
-               sync = [self._state.eq(self._next_state)]
-               return Fragment(comb, sync)
+               self.sync += self.state.eq(self.next_state)
index 20f4972461e4846737f757d99f6e87a794f7b9b2..708fdacf8ba60d0c9e4934d01eaa3adb5ad95ccf 100644 (file)
@@ -166,10 +166,10 @@ class _Compiler:
                states_f, exit_states_f = self.visit_block(node.orelse)
                exit_states = exit_states_t + exit_states_f
                
-               test_state_stmt = If(test, AbstractNextState(states_t[0]))
+               test_state_stmt = If(test, id_next_state(states_t[0]))
                test_state = [test_state_stmt]
                if states_f:
-                       test_state_stmt.Else(AbstractNextState(states_f[0]))
+                       test_state_stmt.Else(id_next_state(states_f[0]))
                else:
                        exit_states.append(test_state)
                
@@ -180,9 +180,9 @@ class _Compiler:
                test = self.ec.visit_expr(node.test)
                states_b, exit_states_b = self.visit_block(node.body)
 
-               test_state = [If(test, AbstractNextState(states_b[0]))]
+               test_state = [If(test, id_next_state(states_b[0]))]
                for exit_state in exit_states_b:
-                       exit_state.insert(0, AbstractNextState(test_state))
+                       exit_state.insert(0, id_next_state(test_state))
                
                sa.assemble([test_state] + states_b, [test_state])
        
@@ -199,7 +199,7 @@ class _Compiler:
                        self.symdict[target] = iteration
                        states_b, exit_states_b = self.visit_block(node.body)
                        for exit_state in last_exit_states:
-                               exit_state.insert(0, AbstractNextState(states_b[0]))
+                               exit_state.insert(0, id_next_state(states_b[0]))
                        last_exit_states = exit_states_b
                        states += states_b
                del self.symdict[target]
index 7bc6c4c227f5f62e7c1567098201b796e5c8c42d..4764ff5b319108111e209b0ae31c3662bb9dab2e 100644 (file)
@@ -1,9 +1,7 @@
-from migen.fhdl import visit as fhdl
-from migen.genlib.fsm import FSM
+from migen.genlib.fsm import FSM, NextState
 
-class AbstractNextState:
-       def __init__(self, target_state):
-               self.target_state = target_state
+def id_next_state(l):
+       return NextState(id(l))
 
 # entry state is first state returned
 class StateAssembler:
@@ -14,37 +12,14 @@ class StateAssembler:
        def assemble(self, n_states, n_exit_states):
                self.states += n_states
                for exit_state in self.exit_states:
-                       exit_state.insert(0, AbstractNextState(n_states[0]))
+                       exit_state.insert(0, id_next_state(n_states[0]))
                self.exit_states = n_exit_states
        
        def ret(self):
                return self.states, self.exit_states
 
-# like list.index, but using "is" instead of comparison
-def _index_is(l, x):
-       for i, e in enumerate(l):
-               if e is x:
-                       return i
-
-class _LowerAbstractNextState(fhdl.NodeTransformer):
-       def __init__(self, fsm, states, stnames):
-               self.fsm = fsm
-               self.states = states
-               self.stnames = stnames
-               
-       def visit_unknown(self, node):
-               if isinstance(node, AbstractNextState):
-                       index = _index_is(self.states, node.target_state)
-                       estate = getattr(self.fsm, self.stnames[index])
-                       return self.fsm.next_state(estate)
-               else:
-                       return node
-
 def implement_fsm(states):
-       stnames = ["S" + str(i) for i in range(len(states))]
-       fsm = FSM(*stnames)
-       lans = _LowerAbstractNextState(fsm, states, stnames)
-       for i, state in enumerate(states):
-               actions = lans.visit(state)
-               fsm.act(getattr(fsm, stnames[i]), *actions)
+       fsm = FSM()
+       for state in states:
+               fsm.act(id(state), state)
        return fsm
index cda649e9b0bc1e78e4d2b90600213f49e80dd154..2efdc77f22be9a177dd3b37b6c3355f5e6e46b32 100644 (file)
@@ -50,7 +50,7 @@ def _gen_df_io(compiler, modelname, to_model, from_model):
                        state += [reg.load(cexpr) for reg in target_regs]
                state += [
                        ep.ack.eq(1),
-                       If(~ep.stb, AbstractNextState(state))
+                       If(~ep.stb, id_next_state(state))
                ]
                return [state], [state]
        else:
@@ -67,7 +67,7 @@ def _gen_df_io(compiler, modelname, to_model, from_model):
                        state.append(signal.eq(compiler.ec.visit_expr(value)))
                state += [
                        ep.stb.eq(1),
-                       If(~ep.ack, AbstractNextState(state))
+                       If(~ep.ack, id_next_state(state))
                ]
                return [state], [state]
 
@@ -110,7 +110,7 @@ def _gen_wishbone_io(compiler, modelname, model, to_model, from_model, bus):
                for target_regs, expr in from_model:
                        cexpr = ec.visit_expr(expr)
                        state += [reg.load(cexpr) for reg in target_regs]
-       state.append(If(~bus.ack, AbstractNextState(state)))
+       state.append(If(~bus.ack, id_next_state(state)))
        return [state], [state]
 
 def _gen_memory_io(compiler, modelname, model, to_model, from_model, port):
@@ -128,7 +128,7 @@ def _gen_memory_io(compiler, modelname, model, to_model, from_model, port):
                return [s1], [s1]
        else:
                s2 = []
-               s1.append(AbstractNextState(s2))
+               s1.append(id_next_state(s2))
                ec = _BusReadExprCompiler(compiler.symdict, modelname, port.dat_r)
                for target_regs, expr in from_model:
                        cexpr = ec.visit_expr(expr)