Instance support
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Thu, 8 Dec 2011 15:35:32 +0000 (16:35 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Thu, 8 Dec 2011 15:35:32 +0000 (16:35 +0100)
examples/divider_conv.py
examples/lm32_inst.py [new file with mode: 0644]
examples/simple_gpio.py
migen/fhdl/convtools.py
migen/fhdl/structure.py
migen/fhdl/verilog.py

index 988e0b8a65d786c31874878fd0827001a59badf6..6696fdb86480e4dd00b6780b46b667e65609246d 100644 (file)
@@ -45,6 +45,6 @@ class Divider:
                return f.Fragment(comb, sync)
 
 d = Divider(32)
-f = d.GetFragment()
-o = verilog.Convert(f, {d.ready_o, d.quotient_o, d.remainder_o}, {d.start_i, d.dividend_i, d.divisor_i})
+frag = d.GetFragment()
+o = verilog.Convert(frag, {d.ready_o, d.quotient_o, d.remainder_o, d.start_i, d.dividend_i, d.divisor_i})
 print(o)
\ No newline at end of file
diff --git a/examples/lm32_inst.py b/examples/lm32_inst.py
new file mode 100644 (file)
index 0000000..1af1751
--- /dev/null
@@ -0,0 +1,47 @@
+from migen.fhdl import structure as f
+from migen.fhdl import verilog
+
+class LM32:
+       def __init__(self):
+               self.inst = f.Instance("lm32_top",
+                       [("I_ADR_O", f.BV(32)),
+                       ("I_DAT_O", f.BV(32)),
+                       ("I_SEL_O", f.BV(4)),
+                       ("I_CYC_O", f.BV(1)),
+                       ("I_STB_O", f.BV(1)),
+                       ("I_WE_O", f.BV(1)),
+                       ("I_CTI_O", f.BV(3)),
+                       ("I_LOCK_O", f.BV(1)),
+                       ("I_BTE_O", f.BV(1)),
+                       ("D_ADR_O", f.BV(32)),
+                       ("D_DAT_O", f.BV(32)),
+                       ("D_SEL_O", f.BV(4)),
+                       ("D_CYC_O", f.BV(1)),
+                       ("D_STB_O", f.BV(1)),
+                       ("D_WE_O", f.BV(1)),
+                       ("D_CTI_O", f.BV(3)),
+                       ("D_LOCK_O", f.BV(1)),
+                       ("D_BTE_O", f.BV(1))],
+                       [("interrupt", f.BV(32)),
+                       ("ext_break", f.BV(1)),
+                       ("I_DAT_I", f.BV(32)),
+                       ("I_ACK_I", f.BV(1)),
+                       ("I_ERR_I", f.BV(1)),
+                       ("I_RTY_I", f.BV(1)),
+                       ("D_DAT_I", f.BV(32)),
+                       ("D_ACK_I", f.BV(1)),
+                       ("D_ERR_I", f.BV(1)),
+                       ("D_RTY_I", f.BV(1))],
+                       [],
+                       "clk_i",
+                       "rst_i",
+                       "lm32")
+       
+       def GetFragment(self):
+               return f.Fragment(instances=[self.inst])
+
+cpus = [LM32() for i in range(4)]
+frag = f.Fragment()
+for cpu in cpus:
+       frag += cpu.GetFragment()
+print(verilog.Convert(frag, set([cpus[0].inst.ins["interrupt"], cpus[0].inst.outs["I_WE_O"]])))
\ No newline at end of file
index 2215a8164889ef9a8e0278b0c08982b1ba3b7979..29e75ebd0328e2299ed81111d9ab62a411eecb9e 100644 (file)
@@ -21,5 +21,5 @@ bank = csrgen.Bank([oreg, ireg])
 f = bank.GetFragment() + inf
 i = bank.interface
 ofield.dev_r.name = "gpio_out"
-v = verilog.Convert(f, {i.d_o, ofield.dev_r}, {i.a_i, i.we_i, i.d_i, gpio_in})
+v = verilog.Convert(f, {i.d_o, ofield.dev_ri.a_i, i.we_i, i.d_i, gpio_in})
 print(v)
\ No newline at end of file
index 3dce8a889a8294c297d16d0b9ec4ce5a3b1c3c12..f8051749d445799ee647d6855b7059ed0a7c05b1 100644 (file)
@@ -70,9 +70,20 @@ def ListTargets(node):
        elif isinstance(node, Case):
                l = list(map(lambda x: ListTargets(x[1]), node.cases))
                return ListTargets(node.default).union(*l)
+       elif isinstance(node, Fragment):
+               return ListTargets(node.comb) | ListTargets(node.sync)
        else:
                raise TypeError
 
+def ListInstOuts(i):
+       if isinstance(i, Fragment):
+               return ListInstOuts(i.instances)
+       else:
+               l = []
+               for x in i:
+                       l += list(map(lambda x: x[1], list(x.outs.items())))
+               return set(l)
+
 def IsVariable(node):
        if isinstance(node, Signal):
                return node.variable
index 4116c889297bede7873ccc586ab2f4d502296e24..f3ff20443ad4751dbb94be900988b69705ea00ed 100644 (file)
@@ -150,10 +150,32 @@ class Case:
 
 #
 
+class Instance:
+       def __init__(self, of, outs=[], ins=[], parameters=[], clkport="", rstport="", name=""):
+               self.of = of
+               if name:
+                       self.name = name
+               else:
+                       self.name = of
+               self.outs = dict([(x[0], Signal(x[1], self.name + "_" + x[0])) for x in outs])
+               self.ins = dict([(x[0], Signal(x[1], self.name + "_" + x[0])) for x in ins])
+               self.parameters = parameters
+               self.clkport = clkport
+               self.rstport = rstport
+
+       def __hash__(self):
+               return id(self)
+
 class Fragment:
-       def __init__(self, comb=StatementList(), sync=StatementList()):
+       def __init__(self, comb=StatementList(), sync=StatementList(), instances=[]):
                self.comb = _sl(comb)
                self.sync = _sl(sync)
+               self.instances = instances
        
        def __add__(self, other):
-               return Fragment(self.comb.l + other.comb.l, self.sync.l + other.sync.l)
\ No newline at end of file
+               return Fragment(self.comb.l + other.comb.l, self.sync.l + other.sync.l, self.instances + other.instances)
+       def __iadd__(self, other):
+               self.comb.l += other.comb.l
+               self.sync.l += other.sync.l
+               self.instances += other.instances
+               return self
\ No newline at end of file
index 6fac3f7f81f28f921cb5adfaea0f034a74e65a65..979005b170e534c6db8cb7b79ee49c5464ecc96b 100644 (file)
@@ -2,105 +2,149 @@ from .structure import *
 from .convtools import *
 from functools import partial
 
-def Convert(f, outs=set(), ins=set(), name="top", clkname="sys_clk", rstname="sys_rst"):
+def _printsig(ns, s):
+       if s.bv.signed:
+               n = "signed "
+       else:
+               n = ""
+       if s.bv.width > 1:
+               n += "[" + str(s.bv.width-1) + ":0] "
+       n += ns.GetName(s)
+       return n
+
+def _printexpr(ns, node):
+       if isinstance(node, Constant):
+               if node.n >= 0:
+                       return str(node.bv) + str(node.n)
+               else:
+                       return "-" + str(node.bv) + str(-self.n)
+       elif isinstance(node, Signal):
+               return ns.GetName(node)
+       elif isinstance(node, Operator):
+               arity = len(node.operands)
+               if arity == 1:
+                       r = self.op + _printexpr(ns, node.operands[0])
+               elif arity == 2:
+                       r = _printexpr(ns, node.operands[0]) + " " + node.op + " " + _printexpr(ns, node.operands[1])
+               else:
+                       raise TypeError
+               return "(" + r + ")"
+       elif isinstance(node, Slice):
+               if node.start + 1 == node.stop:
+                       sr = "[" + str(node.start) + "]"
+               else:
+                       sr = "[" + str(node.stop-1) + ":" + str(node.start) + "]"
+               return _printexpr(ns, node.value) + sr
+       elif isinstance(node, Cat):
+               l = list(map(partial(_printexpr, ns), node.l))
+               l.reverse()
+               return "{" + ", ".join(l) + "}"
+       else:
+               raise TypeError
+
+def _printnode(ns, level, comb, node):
+       if isinstance(node, Assign):
+               if comb or IsVariable(node.l):
+                       assignment = " = "
+               else:
+                       assignment = " <= "
+               return "\t"*level + _printexpr(ns, node.l) + assignment + _printexpr(ns, node.r) + ";\n"
+       elif isinstance(node, StatementList):
+               return "".join(list(map(partial(_printnode, ns, level, comb), node.l)))
+       elif isinstance(node, If):
+               r = "\t"*level + "if (" + _printexpr(ns, node.cond) + ") begin\n"
+               r += _printnode(ns, level + 1, comb, node.t)
+               if node.f.l:
+                       r += "\t"*level + "end else begin\n"
+                       r += _printnode(ns, level + 1, comb, node.f)
+               r += "\t"*level + "end\n"
+               return r
+       elif isinstance(node, Case):
+               r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n"
+               for case in node.cases:
+                       r += "\t"*(level + 1) + _printexpr(ns, case[0]) + ": begin\n"
+                       r += _printnode(ns, level + 2, comb, case[1])
+                       r += "\t"*(level + 1) + "end\n"
+               r += "\t"*level + "endcase\n"
+               return r
+       else:
+               raise TypeError
+
+def _printinstances(ns, i, clk, rst):
+       r = ""
+       for x in i:
+               r += x.of + " "
+               if x.parameters:
+                       r += "#(\n"
+                       firstp = True
+                       for p in x.parameters:
+                               if not firstp:
+                                       r += ",\n"
+                               firstp = False
+                               r += "\t." + p[0] + "("
+                               if isinstance(p[1], int):
+                                       r += str(p[1])
+                               elif isinstance(p[1], basestring):
+                                       r += "\"" + p[1] + "\""
+                               else:
+                                       raise TypeError
+                               r += ")"
+                       r += "\n) "
+               r += ns.GetName(x) + "(\n"
+               ports = list(x.ins.items()) + list(x.outs.items())
+               if x.clkport:
+                       ports.append((x.clkport, clk))
+               if x.rstport:
+                       ports.append((x.rstport, rst))
+               firstp = True
+               for p in ports:
+                       if not firstp:
+                               r += ",\n"
+                       firstp = False
+                       r += "\t." + p[0] + "(" + ns.GetName(p[1]) + ")"
+               if not firstp:
+                       r += "\n"
+               r += ");\n\n"
+       return r
+
+def Convert(f, ios=set(), name="top", clkname="sys_clk", rstname="sys_rst"):
        ns = Namespace()
        
        clks = Signal(name=clkname)
        rsts = Signal(name=rstname)
-       clk = ns.GetName(clks)
-       rst = ns.GetName(rsts)
-       
-       def printsig(s):
-               if s.bv.signed:
-                       n = "signed "
-               else:
-                       n = ""
-               if s.bv.width > 1:
-                       n += "[" + str(s.bv.width-1) + ":0] "
-               n += ns.GetName(s)
-               return n
-       
-       def printexpr(node):
-               if isinstance(node, Constant):
-                       if node.n >= 0:
-                               return str(node.bv) + str(node.n)
-                       else:
-                               return "-" + str(node.bv) + str(-self.n)
-               elif isinstance(node, Signal):
-                       return ns.GetName(node)
-               elif isinstance(node, Operator):
-                       arity = len(node.operands)
-                       if arity == 1:
-                               r = self.op + printexpr(node.operands[0])
-                       elif arity == 2:
-                               r = printexpr(node.operands[0]) + " " + node.op + " " + printexpr(node.operands[1])
-                       else:
-                               raise TypeError
-                       return "(" + r + ")"
-               elif isinstance(node, Slice):
-                       if node.start + 1 == node.stop:
-                               sr = "[" + str(node.start) + "]"
-                       else:
-                               sr = "[" + str(node.stop-1) + ":" + str(node.start) + "]"
-                       return printexpr(node.value) + sr
-               elif isinstance(node, Cat):
-                       l = list(map(printexpr, node.l))
-                       l.reverse()
-                       return "{" + ", ".join(l) + "}"
-               else:
-                       raise TypeError
+
+       sigs = ListSignals(f)
+       targets = ListTargets(f)
+       instouts = ListInstOuts(f)
        
-       def printnode(level, comb, node):
-               if isinstance(node, Assign):
-                       if comb or IsVariable(node.l):
-                               assignment = " = "
-                       else:
-                               assignment = " <= "
-                       return "\t"*level + printexpr(node.l) + assignment + printexpr(node.r) + ";\n"
-               elif isinstance(node, StatementList):
-                       return "".join(list(map(partial(printnode, level, comb), node.l)))
-               elif isinstance(node, If):
-                       r = "\t"*level + "if (" + printexpr(node.cond) + ") begin\n"
-                       r += printnode(level + 1, comb, node.t)
-                       if node.f.l:
-                               r += "\t"*level + "end else begin\n"
-                               r += printnode(level + 1, comb, node.f)
-                       r += "\t"*level + "end\n"
-                       return r
-               elif isinstance(node, Case):
-                       r = "\t"*level + "case (" + printexpr(node.test) + ")\n"
-                       for case in node.cases:
-                               r += "\t"*(level + 1) + printexpr(case[0]) + ": begin\n"
-                               r += printnode(level + 2, comb, case[1])
-                               r += "\t"*(level + 1) + "end\n"
-                       r += "\t"*level + "endcase\n"
-                       return r
-               else:
-                       raise TypeError
-               
-       r = "/* Autogenerated by Migen */\n"
+       r = "/* Machine-generated using Migen */\n"
        r += "module " + name + "(\n"
-       r += "\tinput " + clk + ",\n"
-       r += "\tinput " + rst
-       if ins:
-               r += ",\n\tinput " + ",\n\tinput ".join(map(printsig, ins)) 
-       if outs:
-               r += ",\n\toutput reg " + ",\n\toutput reg ".join(map(printsig, outs)) 
+       r += "\tinput " + ns.GetName(clks) + ",\n"
+       r += "\tinput " + ns.GetName(rsts)
+       for sig in ios:
+               if sig in targets:
+                       r += ",\n\toutput reg " + _printsig(ns, sig)
+               elif sig in instouts:
+                       r += ",\n\toutput " + _printsig(ns, sig)
+               else:
+                       r += ",\n\tinput " + _printsig(ns, sig)
        r += "\n);\n\n"
-       sigs = ListSignals(f).difference(ins, outs)
-       for sig in sigs:
-               r += "reg " + printsig(sig) + ";\n"
+       for sig in sigs - ios:
+               if sig in instouts:
+                       r += "wire " + _printsig(ns, sig) + ";\n"
+               else:
+                       r += "reg " + _printsig(ns, sig) + ";\n"
        r += "\n"
        
        if f.comb.l:
                r += "always @(*) begin\n"
-               r += printnode(1, True, f.comb)
+               r += _printnode(ns, 1, True, f.comb)
                r += "end\n\n"
-       
        if f.sync.l:
-               r += "always @(posedge " + clk + ") begin\n"
-               r += printnode(1, False, InsertReset(rsts, f.sync))
+               r += "always @(posedge " + ns.GetName(clks) + ") begin\n"
+               r += _printnode(ns, 1, False, InsertReset(rsts, f.sync))
                r += "end\n\n"
+       r += _printinstances(ns, f.instances, clks, rsts)
        
        r += "endmodule\n"