pytholite/compiler: support if statements
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Fri, 9 Nov 2012 18:37:52 +0000 (19:37 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Fri, 9 Nov 2012 18:37:52 +0000 (19:37 +0100)
migen/pytholite/compiler.py

index acac97387afd2abdc51b8c43bdb194d1bceb2f3c..a46176a7ca63ec214464799e94bda2402125703d 100644 (file)
@@ -54,6 +54,10 @@ class _Register:
                sync = [Case(self.sel, *cases)]
                return Fragment(sync=sync)
 
+class _AbstractNextState:
+       def __init__(self, target_state):
+               self.target_state = target_state
+
 class _Compiler:
        def __init__(self, symdict, registers):
                self.symdict = symdict
@@ -76,6 +80,15 @@ class _Compiler:
                                op = self.visit_assign(statement)
                                if op:
                                        states.append(op)
+                       elif isinstance(statement, ast.If):
+                               test = self.visit_expr(statement.test)
+                               states_t = self.visit_block(statement.body)
+                               states_f = self.visit_block(statement.orelse)
+                               test_state_stmt = If(test, _AbstractNextState(states_t[0]))
+                               if states_f:
+                                       test_state_stmt.Else(_AbstractNextState(states_f[0]))
+                               states.append([test_state_stmt])
+                               states += states_t + states_f
                        else:
                                raise NotImplementedError
                return states
@@ -163,10 +176,10 @@ class _Compiler:
                        raise NotImplementedError
        
        def visit_expr_compare(self, node):
-               test = visit_expr(node.test)
+               test = self.visit_expr(node.left)
                r = None
                for op, rcomparator in zip(node.ops, node.comparators):
-                       comparator = visit_expr(rcomparator)
+                       comparator = self.visit_expr(rcomparator)
                        if isinstance(op, ast.Eq):
                                comparison = test == comparator
                        elif isinstance(op, ast.NotEq):
@@ -197,15 +210,32 @@ class _Compiler:
        def visit_expr_num(self, node):
                return node.n
 
+# like list.index, but using "is" instead of comparison
+def _index_is(l, x):
+       for i, e in enumerate(l):
+               if e is x:
+                       return i
+
+class _LowerAbstractNextState(fhdl.NodeTransformer):
+       def __init__(self, fsm, states, stnames):
+               self.fsm = fsm
+               self.states = states
+               self.stnames = stnames
+               
+       def visit_unknown(self, node):
+               if isinstance(node, _AbstractNextState):
+                       index = _index_is(self.states, node.target_state)
+                       estate = getattr(self.fsm, self.stnames[index])
+                       return self.fsm.next_state(estate)
+               else:
+                       return node
+
 def _create_fsm(states):
        stnames = ["S" + str(i) for i in range(len(states))]
        fsm = FSM(*stnames)
+       lans = _LowerAbstractNextState(fsm, states, stnames)
        for i, state in enumerate(states):
-               if i == len(states) - 1:
-                       actions = []
-               else:
-                       actions = [fsm.next_state(getattr(fsm, stnames[i+1]))]
-               actions += state
+               actions = lans.visit(state)
                fsm.act(getattr(fsm, stnames[i]), *actions)
        return fsm