Multi-clock design support + new instance API
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Mon, 10 Sep 2012 21:45:02 +0000 (23:45 +0200)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Mon, 10 Sep 2012 21:45:02 +0000 (23:45 +0200)
migen/fhdl/structure.py
migen/fhdl/tools.py
migen/fhdl/verilog.py
migen/fhdl/verilog_mem_behavioral.py
migen/sim/generic.py

index 546066e26803fc20feb9f99230e8aee86ca7b5ff..307bf65726d87391972b1e0acfb5adc5e315dfdf 100644 (file)
@@ -1,6 +1,7 @@
 import math
 import inspect
 import re
+from collections import defaultdict
 
 from migen.fhdl import tracer
 
@@ -240,26 +241,49 @@ class Array(list):
 # extras
 
 class Instance:
-       def __init__(self, of, outs=[], ins=[], inouts=[], parameters=[], clkport="", rstport="", name=""):
+       def __init__(self, of, *items, name=""):
                self.of = of
                if name:
                        self.name_override = name
                else:
                        self.name_override = of
-               def process_io(x):
-                       if isinstance(x[1], Signal):
-                               return x # override
-                       elif isinstance(x[1], BV):
-                               return (x[0], Signal(x[1], x[0]))
+               self.items = items
+       
+       class _IO:
+               def __init__(self, name, signal_or_bv):
+                       self.name = name
+                       if isinstance(signal_or_bv, Signal):
+                               self.signal = signal_or_bv
+                       elif isinstance(signal_or_bv, BV):
+                               self.signal = Signal(signal_or_bv, name)
                        else:
                                raise TypeError
-               self.outs = dict(map(process_io, outs))
-               self.ins = dict(map(process_io, ins))
-               self.inouts = dict(map(process_io, inouts))
-               self.parameters = parameters
-               self.clkport = clkport
-               self.rstport = rstport
-
+       class Input(_IO):
+               pass    
+       class Output(_IO):
+               pass
+       class InOut(_IO):
+               pass
+
+       class Parameter:
+               def __init__(self, name, value):
+                       self.name = name
+                       self.value = value
+       
+       class _CR:
+               def __init__(self, name_inst, domain="sys"):
+                       self.name_inst = name_inst
+                       self.domain = domain
+       class ClockPort(_CR):
+               pass
+       class ResetPort(_CR):
+               pass
+       
+       def get_io(self, name):
+               for item in self.items:
+                       if isinstance(item, Instance._IO) and item.name == name:
+                               return item.signal
+       
        def __hash__(self):
                return id(self)
 
@@ -267,7 +291,8 @@ class Instance:
 
 class MemoryPort:
        def __init__(self, adr, dat_r, we=None, dat_w=None,
-         async_read=False, re=None, we_granularity=0, mode=WRITE_FIRST):
+         async_read=False, re=None, we_granularity=0, mode=WRITE_FIRST,
+         clock_domain="sys"):
                self.adr = adr
                self.dat_r = dat_r
                self.we = we
@@ -276,6 +301,7 @@ class MemoryPort:
                self.re = re
                self.we_granularity = we_granularity
                self.mode = mode
+               self.clock_domain = clock_domain
 
 class Memory:
        def __init__(self, width, depth, *ports, init=None):
@@ -289,24 +315,66 @@ class Memory:
 class Fragment:
        def __init__(self, comb=None, sync=None, instances=None, memories=None, sim=None):
                if comb is None: comb = []
-               if sync is None: sync = []
+               if sync is None: sync = dict()
                if instances is None: instances = []
                if memories is None: memories = []
                if sim is None: sim = []
+               
+               if isinstance(sync, list):
+                       sync = {"sys": sync}
+               
                self.comb = comb
                self.sync = sync
                self.instances = instances
                self.memories = memories
                self.sim = sim
+               
        
        def __add__(self, other):
-               return Fragment(self.comb + other.comb,
-                       self.sync + other.sync,
+               newsync = defaultdict(list)
+               for k, v in self.sync.items():
+                       newsync[k] = v[:]
+               for k, v in other.sync.items():
+                       newsync[k].extend(v)
+               return Fragment(self.comb + other.comb, newsync,
                        self.instances + other.instances,
                        self.memories + other.memories,
                        self.sim + other.sim)
-
+       
+       def rename_clock_domain(self, old, new):
+               self.sync["new"] = self.sync["old"]
+               del self.sync["old"]
+               for inst in self.instances:
+                       for cr in filter(lambda x: isinstance(x, Instance._CR), inst.items):
+                               if cr.domain == old:
+                                       cr.domain = new
+               for mem in self.memories:
+                       for port in mem.ports:
+                               if port.clock_domain == old:
+                                       port.clock_domain = new
+
+       def get_clock_domains(self):
+               r = set(self.sync.keys())
+               r |= set(cr.domain 
+                       for inst in self.instances
+                       for cr in filter(lambda x: isinstance(x, Instance._CR), inst.items))
+               r |= set(port.clock_domain
+                       for mem in self.memories
+                       for port in mem.ports)
+               return r
+       
        def call_sim(self, simulator):
                for s in self.sim:
                        if simulator.cycle_counter >= 0 or (hasattr(s, "initialize") and s.initialize):
                                s(simulator)
+
+class ClockDomain:
+       def __init__(self, n1, n2=None):
+               if n2 is None:
+                       n_clk = n1 + "_clk"
+                       n_rst = n1 + "_rst"
+               else:
+                       n_clk = n1
+                       n_rst = n2
+               self.clk = Signal(name_override=n_clk)
+               self.rst = Signal(name_override=n_rst)
index a753d85e68d1630d5bce5ac768c715f109cd970b..90ec62e525ba04435fce4b0c5c638289507b8c8a 100644 (file)
@@ -31,7 +31,10 @@ def list_signals(node):
                l = list(map(lambda x: list_signals(x[1]), node.cases))
                return list_signals(node.test).union(*l).union(list_signals(node.default))
        elif isinstance(node, Fragment):
-               return list_signals(node.comb) | list_signals(node.sync)
+               l = list_signals(node.comb)
+               for k, v in node.sync.items():
+                       l |= list_signals(v)
+               return l
        else:
                raise TypeError
 
@@ -56,7 +59,10 @@ def list_targets(node):
                l = list(map(lambda x: list_targets(x[1]), node.cases))
                return list_targets(node.default).union(*l)
        elif isinstance(node, Fragment):
-               return list_targets(node.comb) | list_targets(node.sync)
+               l = list_targets(node.comb)
+               for k, v in node.sync.items():
+                       l |= list_targets(v)
+               return l
        else:
                raise TypeError
 
@@ -78,16 +84,17 @@ def group_by_targets(sl):
 def list_inst_ios(i, ins, outs, inouts):
        if isinstance(i, Fragment):
                return list_inst_ios(i.instances, ins, outs, inouts)
+       elif isinstance(i, list):
+               if i:
+                       return set.union(*(list_inst_ios(e, ins, outs, inouts) for e in i))
+               else:
+                       return set()
        else:
-               l = []
-               for x in i:
-                       if ins:
-                               l += x.ins.values()
-                       if outs:
-                               l += x.outs.values()
-                       if inouts:
-                               l += x.inouts.values()
-               return set(l)
+               return set(item.signal for item in filter(lambda x:
+                       (ins and isinstance(x, Instance.Input))
+                       or (outs and isinstance(x, Instance.Output))
+                       or (inouts and isinstance(x, Instance.InOut)),
+                       i.items))
 
 def list_mem_ios(m, ins, outs):
        if isinstance(m, Fragment):
@@ -254,6 +261,10 @@ def _lower_arrays_sl(sl):
 def lower_arrays(f):
        f = copy(f)
        f.comb, ec1 = _lower_arrays_sl(f.comb)
-       f.sync, ec2 = _lower_arrays_sl(f.sync)
-       f.comb += ec1 + ec2
+       f.comb += ec1
+       newsync = dict()
+       for k, v in f.sync.items():
+               newsync[k], ec2 = _lower_arrays_sl(v)
+               f.comb += ec2
+       f.sync = newsync
        return f
index c4bfbe154d976ff9b6d543dd057221f27a56e2df..034ac2c6e9a8235815f9b7216903f9b688a0d870 100644 (file)
@@ -169,57 +169,64 @@ def _printcomb(f, ns, display_run):
        r += "\n"
        return r
 
-def _printsync(f, ns, clk, rst):
+def _printsync(f, ns, clock_domains):
        r = ""
-       if f.sync:
-               r += "always @(posedge " + ns.get_name(clk) + ") begin\n"
-               r += _printnode(ns, _AT_SIGNAL, 1, insert_reset(rst, f.sync))
+       for k, v in f.sync.items():
+               r += "always @(posedge " + ns.get_name(clock_domains[k].clk) + ") begin\n"
+               r += _printnode(ns, _AT_SIGNAL, 1, insert_reset(clock_domains[k].rst, v))
                r += "end\n\n"
        return r
 
-def _printinstances(f, ns, clk, rst):
+def _printinstances(f, ns, clock_domains):
        r = ""
        for x in f.instances:
+               parameters = list(filter(lambda i: isinstance(i, Instance.Parameter), x.items))
                r += x.of + " "
-               if x.parameters:
+               if parameters:
                        r += "#(\n"
                        firstp = True
-                       for p in x.parameters:
+                       for p in parameters:
                                if not firstp:
                                        r += ",\n"
                                firstp = False
-                               r += "\t." + p[0] + "("
-                               if isinstance(p[1], int) or isinstance(p[1], float) or isinstance(p[1], Constant):
-                                       r += str(p[1])
-                               elif isinstance(p[1], str):
-                                       r += "\"" + p[1] + "\""
+                               r += "\t." + p.name + "("
+                               if isinstance(p.value, int) or isinstance(p.value, float) or isinstance(p.value, Constant):
+                                       r += str(p.value)
+                               elif isinstance(p.value, str):
+                                       r += "\"" + p.value + "\""
                                else:
                                        raise TypeError
                                r += ")"
                        r += "\n) "
                r += ns.get_name(x) 
-               if x.parameters: r += " "
+               if parameters: r += " "
                r += "(\n"
-               ports = list(x.ins.items()) + list(x.outs.items()) + list(x.inouts.items())
-               if x.clkport:
-                       ports.append((x.clkport, clk))
-               if x.rstport:
-                       ports.append((x.rstport, rst))
                firstp = True
-               for p in ports:
+               for p in x.items:
+                       if isinstance(p, Instance._IO):
+                               name_inst = p.name
+                               name_design = ns.get_name(p.signal)
+                       elif isinstance(p, Instance.ClockPort):
+                               name_inst = p.name_inst
+                               name_design = ns.get_name(clock_domains[p.domain].clk)
+                       elif isinstance(p, Instance.ResetPort):
+                               name_inst = p.name_inst
+                               name_design = ns.get_name(clock_domains[p.domain].rst)
+                       else:
+                               continue
                        if not firstp:
                                r += ",\n"
                        firstp = False
-                       r += "\t." + p[0] + "(" + ns.get_name(p[1]) + ")"
+                       r += "\t." + name_inst + "(" + name_design + ")"
                if not firstp:
                        r += "\n"
                r += ");\n\n"
        return r
 
-def _printmemories(f, ns, handler, clk):
+def _printmemories(f, ns, handler, clock_domains):
        r = ""
        for memory in f.memories:
-               r += handler(memory, ns, clk)
+               r += handler(memory, ns, clock_domains)
        return r
 
 def _printinit(f, ios, ns):
@@ -237,16 +244,17 @@ def _printinit(f, ios, ns):
        return r
 
 def convert(f, ios=set(), name="top",
-  clk_signal=None, rst_signal=None,
+  clock_domains=None,
   return_ns=False,
   memory_handler=verilog_mem_behavioral.handler,
   display_run=False):
-       if clk_signal is None:
-               clk_signal = Signal(name_override="sys_clk")
-               ios.add(clk_signal)
-       if rst_signal is None:
-               rst_signal = Signal(name_override="sys_rst")
-               ios.add(rst_signal)
+       if clock_domains is None:
+               clock_domains = dict()
+               for d in f.get_clock_domains():
+                       cd = ClockDomain(d)
+                       clock_domains[d] = cd
+                       ios.add(cd.clk)
+                       ios.add(cd.rst)
                
        f = lower_arrays(f)
 
@@ -258,9 +266,9 @@ def convert(f, ios=set(), name="top",
        r = "/* Machine-generated using Migen */\n"
        r += _printheader(f, ios, name, ns)
        r += _printcomb(f, ns, display_run)
-       r += _printsync(f, ns, clk_signal, rst_signal)
-       r += _printinstances(f, ns, clk_signal, rst_signal)
-       r += _printmemories(f, ns, memory_handler, clk_signal)
+       r += _printsync(f, ns, clock_domains)
+       r += _printinstances(f, ns, clock_domains)
+       r += _printmemories(f, ns, memory_handler, clock_domains)
        r += _printinit(f, ios, ns)
        r += "endmodule\n"
 
index 59fb86f8ea5dfd1528d9c860f3a759d3874ed60d..a9c88e8cc56d8af4b7d85847addb46abf40864c8 100644 (file)
@@ -1,6 +1,6 @@
 from migen.fhdl.structure import *
 
-def handler(memory, ns, clk):
+def handler(memory, ns, clock_domains):
        r = ""
        gn = ns.get_name
        adrbits = bits_for(memory.depth-1)
@@ -24,8 +24,8 @@ def handler(memory, ns, clk):
                                        + gn(data_reg) + ";\n"
                                data_regs[id(port)] = data_reg
 
-       r += "always @(posedge " + gn(clk) + ") begin\n"
        for port in memory.ports:
+               r += "always @(posedge " + gn(clock_domains[port.clock_domain].clk) + ") begin\n"
                if port.we is not None:
                        if port.we_granularity:
                                n = memory.width//port.we_granularity
@@ -53,7 +53,7 @@ def handler(memory, ns, clk):
                else:
                        r += "\tif (" + gn(port.re) + ")\n"
                        r += "\t" + rd.replace("\n\t", "\n\t\t")
-       r += "end\n\n"
+               r += "end\n\n"
        
        for port in memory.ports:
                if port.async_read:
index e027dc9c8451bbc1dd8da8d0e3492498a074a35f..b4bd0ab0ddd535425f3769e28a6eab3b54b593a2 100644 (file)
@@ -14,9 +14,14 @@ class TopLevel:
                self.top_name = top_name
                self.dut_type = dut_type
                self.dut_name = dut_name
-               self.clk_name = clk_name
-               self.clk_period = clk_period
-               self.rst_name = rst_name
+               
+               self._clk_name = clk_name
+               self._clk_period = clk_period
+               self._rst_name = rst_name
+               
+               cd = ClockDomain(self._clk_name, self._rst_name)
+               self.clock_domains = {"sys": cd}
+               self.ios = {cd.clk, cd.rst}
        
        def get(self, sockaddr):
                template1 = """`timescale 1ns / 1ps
@@ -56,9 +61,9 @@ end
                r = template1.format(top_name=self.top_name,
                        dut_type=self.dut_type,
                        dut_name=self.dut_name,
-                       clk_name=self.clk_name,
-                       hclk_period=str(self.clk_period/2),
-                       rst_name=self.rst_name,
+                       clk_name=self._clk_name,
+                       hclk_period=str(self._clk_period/2),
+                       rst_name=self._rst_name,
                        sockaddr=sockaddr)
                if self.vcd_name is not None:
                        r += template2.format(vcd_name=self.vcd_name,
@@ -78,13 +83,10 @@ class Simulator:
                
                c_top = self.top_level.get(sockaddr)
                
-               clk_signal = Signal(name_override=self.top_level.clk_name)
-               rst_signal = Signal(name_override=self.top_level.rst_name)
                c_fragment, self.namespace = verilog.convert(fragment,
-                       {clk_signal, rst_signal},
+                       ios=self.top_level.ios,
                        name=self.top_level.dut_type,
-                       clk_signal=clk_signal,
-                       rst_signal=rst_signal,
+                       clock_domains=self.top_level.clock_domains,
                        return_ns=True,
                        **vopts)