fhdl/tools: use NodeTransformer to lower arrays
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Wed, 28 Nov 2012 16:46:15 +0000 (17:46 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Wed, 28 Nov 2012 16:46:15 +0000 (17:46 +0100)
migen/fhdl/tools.py

index cb797208f71db9a19fd7fc6085fb7d67146e7ec3..0824e46690ef7e0fa1ecbf19f37af9d36f7b6737 100644 (file)
@@ -1,8 +1,6 @@
-from copy import copy
-
 from migen.fhdl.structure import *
 from migen.fhdl.structure import _Operator, _Slice, _Assign, _ArrayProxy
-from migen.fhdl.visit import NodeVisitor
+from migen.fhdl.visit import NodeVisitor, NodeTransformer
 
 class _SignalLister(NodeVisitor):
        def __init__(self):
@@ -144,100 +142,32 @@ def value_bv(v):
        else:
                raise TypeError
 
-def _lower_arrays_values(vl):
-       r = []
-       extra_comb = []
-       for v in vl:
-               v2, e = _lower_arrays_value(v)
-               extra_comb += e
-               r.append(v2)
-       return r, extra_comb
-
-def _lower_arrays_value(v):
-       if isinstance(v, Constant):
-               return v, []
-       elif isinstance(v, Signal):
-               return v, []
-       elif isinstance(v, _Operator):
-               op2, e = _lower_arrays_values(v.operands)
-               return _Operator(v.op, op2), e
-       elif isinstance(v, _Slice):
-               v2, e = _lower_arrays_value(v.value)
-               return _Slice(v2, v.start, v.stop), e
-       elif isinstance(v, Cat):
-               l2, e = _lower_arrays_values(v.l)
-               return Cat(*l2), e
-       elif isinstance(v, Replicate):
-               v2, e = _lower_arrays_value(v.v)
-               return Replicate(v2, v.n), e
-       elif isinstance(v, _ArrayProxy):
-               choices2, e = _lower_arrays_values(v.choices)
-               array_muxed = Signal(value_bv(v))
-               cases = [[Constant(n), _Assign(array_muxed, choice)]
-                       for n, choice in enumerate(choices2)]
-               cases[-1][0] = Default()
-               e.append(Case(v.key, *cases))
-               return array_muxed, e
-
-def _lower_arrays_assign(l, r):
-       extra_comb = []
-       if isinstance(l, _ArrayProxy):
-               k, e = _lower_arrays_value(l.key)
-               extra_comb += e
-               cases = []
-               for n, choice in enumerate(l.choices):
-                       assign, e = _lower_arrays_assign(choice, r)
-                       extra_comb += e
-                       cases.append([Constant(n), assign])
+class _ArrayLowerer(NodeTransformer):
+       def __init__(self):
+               self.comb = []
+       
+       def visit_Assign(self, node):
+               if isinstance(node.l, _ArrayProxy):
+                       k = self.visit(node.l.key)
+                       cases = []
+                       for n, choice in enumerate(node.l.choices):
+                               assign = self.visit_Assign(_Assign(choice, node.r))
+                               cases.append([Constant(n), assign])
+                       cases[-1][0] = Default()
+                       return Case(k, *cases)
+               else:
+                       return super().visit_Assign(node)
+       
+       def visit_ArrayProxy(self, node):
+               array_muxed = Signal(value_bv(node))
+               cases = [[Constant(n), _Assign(array_muxed, self.visit(choice))]
+                       for n, choice in enumerate(node.choices)]
                cases[-1][0] = Default()
-               return Case(k, *cases), extra_comb
-       else:
-               return _Assign(l, r), extra_comb
-               
-def _lower_arrays_sl(sl):
-       rs = []
-       extra_comb = []
-       for statement in sl:
-               if isinstance(statement, _Assign):
-                       r, e = _lower_arrays_value(statement.r)
-                       extra_comb += e
-                       r, e = _lower_arrays_assign(statement.l, r)
-                       extra_comb += e
-                       rs.append(r)
-               elif isinstance(statement, If):
-                       cond, e = _lower_arrays_value(statement.cond)
-                       extra_comb += e
-                       t, e = _lower_arrays_sl(statement.t)
-                       extra_comb += e
-                       f, e = _lower_arrays_sl(statement.f)
-                       extra_comb += e
-                       i = If(cond)
-                       i.t = t
-                       i.f = f
-                       rs.append(i)
-               elif isinstance(statement, Case):
-                       test, e = _lower_arrays_value(statement.test)
-                       extra_comb += e
-                       c = Case(test)
-                       for cond, csl in statement.cases:
-                               stmts, e = _lower_arrays_sl(csl)
-                               extra_comb += e
-                               c.cases.append((cond, stmts))
-                       if statement.default is not None:
-                               c.default, e = _lower_arrays_sl(statement.default)
-                               extra_comb += e
-                       rs.append(c)
-               elif statement is not None:
-                       raise TypeError
-       return rs, extra_comb
+               self.comb.append(Case(self.visit(node.key), *cases))
+               return array_muxed
 
 def lower_arrays(f):
-       f = copy(f)
-       f.comb, ec1 = _lower_arrays_sl(f.comb)
-       f.comb += ec1
-       newsync = dict()
-       for k, v in f.sync.items():
-               newsync[k], ec2 = _lower_arrays_sl(v)
-               f.comb += ec2
-       f.sync = newsync
-       return f
+       al = _ArrayLowerer()
+       f2 = al.visit(f)
+       f2.comb += al.comb
+       return f2