class _ArrayLowerer(NodeTransformer):
def __init__(self):
self.comb = []
-
+ self.target_context = False
+ self.extra_stmts = []
+
def visit_Assign(self, node):
- if isinstance(node.l, _ArrayProxy):
- k = self.visit(node.l.key)
- cases = {}
- for n, choice in enumerate(node.l.choices):
- assign = self.visit_Assign(_Assign(choice, node.r))
- cases[n] = [assign]
- return Case(k, cases).makedefault()
- else:
- return NodeTransformer.visit_Assign(self, node)
+ old_target_context, old_extra_stmts = self.target_context, self.extra_stmts
+ self.extra_stmts = []
+
+ self.target_context = True
+ lhs = self.visit(node.l)
+ self.target_context = False
+ rhs = self.visit(node.r)
+ r = _Assign(lhs, rhs)
+ if self.extra_stmts:
+ r = [r] + self.extra_stmts
+
+ self.target_context, self.extra_stmts = old_target_context, old_extra_stmts
+ return r
def visit_ArrayProxy(self, node):
- array_muxed = Signal(value_bits_sign(node))
- cases = dict((n, _Assign(array_muxed, self.visit(choice)))
- for n, choice in enumerate(node.choices))
- self.comb.append(Case(self.visit(node.key), cases).makedefault())
+ array_muxed = Signal(value_bits_sign(node), variable=True)
+ if self.target_context:
+ k = self.visit(node.key)
+ cases = {}
+ for n, choice in enumerate(node.choices):
+ cases[n] = [self.visit_Assign(_Assign(choice, array_muxed))]
+ self.extra_stmts.append(Case(k, cases).makedefault())
+ else:
+ cases = dict((n, _Assign(array_muxed, self.visit(choice)))
+ for n, choice in enumerate(node.choices))
+ self.comb.append(Case(self.visit(node.key), cases).makedefault())
return array_muxed
def lower_arrays(f):
al = _ArrayLowerer()
- f2 = al.visit(f)
- f2.comb += al.comb
- return f2
+ tf = al.visit(f)
+ tf.comb += al.comb
+ return tf
def bitreverse(s):
length, signed = value_bits_sign(s)