transform: unroll
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Fri, 12 Oct 2012 11:16:39 +0000 (13:16 +0200)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Fri, 12 Oct 2012 11:16:39 +0000 (13:16 +0200)
examples/basic/multi_accumulator.py [new file with mode: 0644]
migen/transform/__init__.py [new file with mode: 0644]
migen/transform/unroll.py [new file with mode: 0644]

diff --git a/examples/basic/multi_accumulator.py b/examples/basic/multi_accumulator.py
new file mode 100644 (file)
index 0000000..6f7e050
--- /dev/null
@@ -0,0 +1,19 @@
+from migen.fhdl.structure import *
+from migen.transform.unroll import unroll_sync
+from migen.fhdl import verilog
+
+x = Signal(BV(4))
+y = Signal(BV(4))
+acc = Signal(BV(4))
+
+sync = [
+       acc.eq(acc + x + y)
+]
+
+n = 5
+xs = [Signal(BV(4)) for i in range(n)]
+ys = [Signal(BV(4)) for i in range(n)]
+accs = [Signal(BV(4)) for i in range(n)]
+
+sync_u = unroll_sync(sync, {x: xs, y: ys}, {acc: accs})
+print(verilog.convert(Fragment(sync=sync_u), ios=set(xs+ys+accs)))
diff --git a/migen/transform/__init__.py b/migen/transform/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/migen/transform/unroll.py b/migen/transform/unroll.py
new file mode 100644 (file)
index 0000000..69ec5e9
--- /dev/null
@@ -0,0 +1,94 @@
+from migen.fhdl.structure import *
+from migen.fhdl.structure import _Operator, _Slice, _Assign, _ArrayProxy
+
+# y <= y + a + b
+#
+# unroll_sync(sync, {b: [b1, b2], c: [c1, c2]}, {y: [y1, y2]})
+#
+# ==>
+#
+# v_y1 = y2 + a1 + b1
+# v_y2 = v_y1 + a2 + b2
+# y1 <= v_y1
+# y2 <= v_y2
+
+(_UNDETERMINED, _IN, _OUT) = range(3)
+
+# TODO: arrays
+
+def _replace_if_in(d, s):
+       try:
+               return d[s]
+       except KeyError:
+               return s
+
+def _replace(node, rin, rout, mode=_UNDETERMINED):
+       if isinstance(node, Constant):
+               return node
+       elif isinstance(node, Signal):
+               if mode == _IN:
+                       return _replace_if_in(rin, node)
+               elif mode == _OUT:
+                       return _replace_if_in(rout, node)
+               else:
+                       raise ValueError
+       elif isinstance(node, _Operator):
+               rop = [_replace(op, rin, rout, mode) for op in node.operands]
+               return _Operator(node.op, rop)
+       elif isinstance(node, _Slice):
+               return _Slice(_replace(node.value, rin, rout, mode), node.start, node.stop)
+       elif isinstance(node, Cat):
+               rcomp = [_replace(comp, rin, rout, mode) for comp in node.l]
+               return Cat(*rcomp)
+       elif isinstance(node, Replicate):
+               return Replicate(_replace(node.v, rin, rout, mode), node.n)
+       elif isinstance(node, _Assign):
+               return _Assign(_replace(node.l, rin, rout, _OUT), _replace(node.r, rin, rout, _IN))
+       elif isinstance(node, list):
+               return [_replace(s, rin, rout) for s in node]
+       elif isinstance(node, If):
+               r = If(_replace(node.cond, rin, rout, _IN))
+               r.t = _replace(node.t, rin, rout)
+               r.f = _replace(node.f, rin, rout)
+               return r
+       elif isinstance(node, Case):
+               r = Case(_replace(case.test, rin, rout, _IN))
+               r.cases = [(c[0], _replace(c[1], rin, rout)) for c in node.cases]
+               r.default = _replace(node.default, rin, rout)
+               return r
+       else:
+               raise TypeError
+
+def _list_step_dicts(d):
+       iterdict = dict((k, iter(v)) for k, v in d.items())
+       r = []
+       try:
+               while True:
+                       r.append(dict([(k, next(i)) for k, i in iterdict.items()]))
+       except StopIteration:
+               pass
+       return r
+
+def _variable_for(s, n):
+       sn = s.backtrace[-1][0]
+       if isinstance(sn, str):
+               name = "v" + str(n) + "_" + sn
+       else:
+               name = "v"
+       return Signal(s.bv, name=name, variable=True)
+
+def unroll_sync(sync, inputs, outputs):
+       sd_in = _list_step_dicts(inputs)
+       sd_out = _list_step_dicts(outputs)
+       
+       do_var_old = sd_out[-1]
+       r = []
+       for n, (di, do) in enumerate(zip(sd_in, sd_out)):
+               do_var = dict((k, _variable_for(v, n)) for k, v in do.items())
+               di_plus_do_var_old = di.copy()
+               di_plus_do_var_old.update(do_var_old)
+               r += _replace(sync, di_plus_do_var_old, do_var)
+               r += [v.eq(do_var[k]) for k, v in do.items()]
+               do_var_old = do_var
+       
+       return r