class Case(ast.Switch):
- @deprecated("instead of `Case(test, ...)`, use `with m.Case(test, ...):`")
+ @deprecated("instead of `Case(test, { value: stmts })`, use `with m.Switch(test):` and "
+ "`with m.Case(value): stmts`; instead of `\"default\": stmts`, use "
+ "`with m.Case(): stmts`")
def __init__(self, test, cases):
new_cases = []
for k, v in cases.items():
- if k == "default":
+ if isinstance(k, (bool, int)):
+ k = Const(k)
+ if (not isinstance(k, Const)
+ and not (isinstance(k, str) and k == "default")):
+ raise TypeError("Case object is not a Migen constant")
+ if isinstance(k, str) and k == "default":
k = "-" * len(ast.Value.wrap(test))
+ else:
+ k = k.value
new_cases.append((k, v))
super().__init__(test, OrderedDict(new_cases))
+ @deprecated("instead of `Case(...).makedefault()`, use an explicit default case: "
+ "`with m.Case(): ...`")
def makedefault(self, key=None):
- raise NotImplementedError
+ if key is None:
+ for choice in self.cases.keys():
+ if (key is None
+ or (isinstance(choice, str) and choice == "default")
+ or choice > key):
+ key = choice
+ if isinstance(key, str) and key == "default":
+ key = "-" * len(self.test)
+ else:
+ key = "{:0{}b}".format(wrap(key).value, len(self.test))
+ stmts = self.cases[key]
+ del self.cases[key]
+ self.cases["-" * len(self.test)] = stmts
+ return self
def Array(*args):
--- /dev/null
+import warnings
+from collections import OrderedDict
+
+from ...fhdl.xfrm import ValueTransformer, StatementTransformer
+from ...fhdl.ast import *
+from ..fhdl.module import CompatModule, CompatFinalizeError
+from ..fhdl.structure import If, Case
+
+
+__all__ = ["AnonymousState", "NextState", "NextValue", "FSM"]
+
+
+class AnonymousState:
+ pass
+
+
+class NextState(Statement):
+ def __init__(self, state):
+ self.state = state
+
+
+class NextValue(Statement):
+ def __init__(self, target, value):
+ self.target = target
+ self.value = value
+
+
+def _target_eq(a, b):
+ if type(a) != type(b):
+ return False
+ ty = type(a)
+ if ty == Const:
+ return a.value == b.value
+ elif ty == Signal:
+ return a is b
+ elif ty == Cat:
+ return all(_target_eq(x, y) for x, y in zip(a.l, b.l))
+ elif ty == Slice:
+ return (_target_eq(a.value, b.value)
+ and a.start == b.start
+ and a.stop == b.stop)
+ elif ty == Part:
+ return (_target_eq(a.value, b.value)
+ and _target_eq(a.offset == b.offset)
+ and a.width == b.width)
+ elif ty == ArrayProxy:
+ return (all(_target_eq(x, y) for x, y in zip(a.choices, b.choices))
+ and _target_eq(a.key, b.key))
+ else:
+ raise ValueError("NextValue cannot be used with target type '{}'"
+ .format(ty))
+
+
+class _LowerNext(ValueTransformer, StatementTransformer):
+ def __init__(self, next_state_signal, encoding, aliases):
+ self.next_state_signal = next_state_signal
+ self.encoding = encoding
+ self.aliases = aliases
+ # (target, next_value_ce, next_value)
+ self.registers = []
+
+ def _get_register_control(self, target):
+ for x in self.registers:
+ if _target_eq(target, x[0]):
+ return x[1], x[2]
+ raise KeyError
+
+ def on_unknown_statement(self, node):
+ if isinstance(node, NextState):
+ try:
+ actual_state = self.aliases[node.state]
+ except KeyError:
+ actual_state = node.state
+ return self.next_state_signal.eq(self.encoding[actual_state])
+ elif isinstance(node, NextValue):
+ try:
+ next_value_ce, next_value = self._get_register_control(node.target)
+ except KeyError:
+ related = node.target if isinstance(node.target, Signal) else None
+ next_value = Signal(node.target.shape())
+ next_value_ce = Signal()
+ self.registers.append((node.target, next_value_ce, next_value))
+ return next_value.eq(node.value), next_value_ce.eq(1)
+ else:
+ return node
+
+
+class FSM(CompatModule):
+ def __init__(self, reset_state=None):
+ self.actions = OrderedDict()
+ self.state_aliases = dict()
+ self.reset_state = reset_state
+
+ self.before_entering_signals = OrderedDict()
+ self.before_leaving_signals = OrderedDict()
+ self.after_entering_signals = OrderedDict()
+ self.after_leaving_signals = OrderedDict()
+
+ def act(self, state, *statements):
+ if self.finalized:
+ raise CompatFinalizeError
+ if self.reset_state is None:
+ self.reset_state = state
+ if state not in self.actions:
+ self.actions[state] = []
+ self.actions[state] += statements
+
+ def delayed_enter(self, name, target, delay):
+ if self.finalized:
+ raise CompatFinalizeError
+ if delay > 0:
+ state = name
+ for i in range(delay):
+ if i == delay - 1:
+ next_state = target
+ else:
+ next_state = AnonymousState()
+ self.act(state, NextState(next_state))
+ state = next_state
+ else:
+ self.state_aliases[name] = target
+
+ def ongoing(self, state):
+ is_ongoing = Signal()
+ self.act(state, is_ongoing.eq(1))
+ return is_ongoing
+
+ def _get_signal(self, d, state):
+ if state not in self.actions:
+ self.actions[state] = []
+ try:
+ return d[state]
+ except KeyError:
+ is_el = Signal()
+ d[state] = is_el
+ return is_el
+
+ def before_entering(self, state):
+ return self._get_signal(self.before_entering_signals, state)
+
+ def before_leaving(self, state):
+ return self._get_signal(self.before_leaving_signals, state)
+
+ def after_entering(self, state):
+ signal = self._get_signal(self.after_entering_signals, state)
+ self.sync += signal.eq(self.before_entering(state))
+ return signal
+
+ def after_leaving(self, state):
+ signal = self._get_signal(self.after_leaving_signals, state)
+ self.sync += signal.eq(self.before_leaving(state))
+ return signal
+
+ def do_finalize(self):
+ nstates = len(self.actions)
+ self.encoding = dict((s, n) for n, s in enumerate(self.actions.keys()))
+ self.decoding = {n: s for s, n in self.encoding.items()}
+
+ self.state = Signal(max=nstates, reset=self.encoding[self.reset_state])
+ self.state._enumeration = self.decoding
+ self.next_state = Signal(max=nstates)
+ self.next_state._enumeration = {n: "{}:{}".format(n, s) for n, s in self.decoding.items()}
+
+ for state, signal in self.before_leaving_signals.items():
+ encoded = self.encoding[state]
+ self.comb += signal.eq((self.state == encoded) & ~(self.next_state == encoded))
+ if self.reset_state in self.after_entering_signals:
+ self.after_entering_signals[self.reset_state].reset = 1
+ for state, signal in self.before_entering_signals.items():
+ encoded = self.encoding[state]
+ self.comb += signal.eq(~(self.state == encoded) & (self.next_state == encoded))
+
+ self._finalize_sync(self._lower_controls())
+
+ def _lower_controls(self):
+ return _LowerNext(self.next_state, self.encoding, self.state_aliases)
+
+ def _finalize_sync(self, ls):
+ cases = dict((self.encoding[k], ls.on_statement(v)) for k, v in self.actions.items() if v)
+ with warnings.catch_warnings():
+ self.comb += [
+ self.next_state.eq(self.state),
+ Case(self.state, cases).makedefault(self.encoding[self.reset_state])
+ ]
+ self.sync += self.state.eq(self.next_state)
+ for register, next_value_ce, next_value in ls.registers:
+ self.sync += If(next_value_ce, register.eq(next_value))
-from collections import OrderedDict
+from collections import OrderedDict, Iterable
+from ..tools import flatten
from .ast import *
from .ir import *
def on_Repl(self, value):
return Repl(self.on_value(value.value), value.count)
+ def on_unknown_value(self, value):
+ raise TypeError("Cannot transform value {!r}".format(value)) # :nocov:
+
def on_value(self, value):
if isinstance(value, Const):
new_value = self.on_Const(value)
elif isinstance(value, Repl):
new_value = self.on_Repl(value)
else:
- raise TypeError("Cannot transform value {!r}".format(value)) # :nocov:
+ new_value = self.on_unknown_value(value)
if isinstance(new_value, Value):
new_value.src_loc = value.src_loc
return new_value
return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
def on_Switch(self, stmt):
- cases = OrderedDict((k, self.on_value(v)) for k, v in stmt.cases.items())
+ cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
return Switch(self.on_value(stmt.test), cases)
def on_statements(self, stmt):
- return list(flatten(self.on_statement(stmt) for stmt in self.on_statement(stmt)))
+ return list(flatten(self.on_statement(stmt) for stmt in stmt))
+
+ def on_unknown_statement(self, stmt):
+ raise TypeError("Cannot transform statement {!r}".format(stmt)) # :nocov:
def on_statement(self, stmt):
if isinstance(stmt, Assign):
return self.on_Assign(stmt)
elif isinstance(stmt, Switch):
return self.on_Switch(stmt)
- elif isinstance(stmt, (list, tuple)):
+ elif isinstance(stmt, Iterable):
return self.on_statements(stmt)
else:
- raise TypeError("Cannot transform statement {!r}".format(stmt)) # :nocov:
+ return self.on_unknown_statement(stmt)
def __call__(self, value):
return self.on_statement(value)