Lowering of Special expressions + support ClockSignal/ResetSignal
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Mon, 18 Mar 2013 17:36:50 +0000 (18:36 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Mon, 18 Mar 2013 17:36:50 +0000 (18:36 +0100)
migen/fhdl/module.py
migen/fhdl/specials.py
migen/fhdl/structure.py
migen/fhdl/tools.py
migen/fhdl/verilog.py
migen/fhdl/visit.py
migen/genlib/cdc.py

index 2d52eba934fa9bfab3409d55162a0f2d037b52b7..29750c06bceb14f4edc37903adeb7b940d6ffaf0 100644 (file)
@@ -3,7 +3,7 @@ from itertools import combinations
 
 from migen.fhdl.structure import *
 from migen.fhdl.specials import Special
-from migen.fhdl.tools import flat_iteration
+from migen.fhdl.tools import flat_iteration, rename_clock_domain
 
 class FinalizeError(Exception):
        pass
@@ -158,7 +158,7 @@ class Module:
                        for mod_name, f in subfragments:
                                for cd in f.clock_domains:
                                        if cd.name in needs_renaming:
-                                               f.rename_clock_domain(cd.name, mod_name + "_" + cd.name)
+                                               rename_clock_domain(f, cd.name, mod_name + "_" + cd.name)
                        # sum subfragments
                        for mod_name, f in subfragments:
                                self._fragment += f
index 5c9bac173445ff0c21b0407c5f82b8d6aa12a53b..5da8fd7f95f9540e0daa219f5ec08ba5d038dfba 100644 (file)
@@ -1,14 +1,32 @@
 from migen.fhdl.structure import *
-from migen.fhdl.tools import list_signals, value_bits_sign
+from migen.fhdl.tools import *
 from migen.fhdl.tracer import get_obj_var_name
 from migen.fhdl.verilog import _printexpr as verilog_printexpr
 
 class Special(HUID):
+       def iter_expressions(self):
+               for x in []:
+                       yield x
+
        def rename_clock_domain(self, old, new):
-               pass
+               for obj, attr, direction in self.iter_expressions():
+                       rename_clock_domain_expr(getattr(obj, attr), old, new)
+
+       def list_clock_domains(self):
+               r = set()
+               for obj, attr, direction in self.iter_expressions():
+                       r |= list_clock_domains_expr(getattr(obj, attr))
+               return r
 
-       def get_clock_domains(self):
-               return set()
+       def list_ios(self, ins, outs, inouts):
+               r = set()
+               for obj, attr, direction in self.iter_expressions():
+                       if (direction == SPECIAL_INPUT and ins) \
+                         or (direction == SPECIAL_OUTPUT and outs) \
+                         or (direction == SPECIAL_INOUT and inouts):
+                               signals = list_signals(getattr(obj, attr))
+                               r.update(signals)
+               return r
 
 class Tristate(Special):
        def __init__(self, target, o, oe, i=None):
@@ -18,16 +36,13 @@ class Tristate(Special):
                self.oe = oe
                self.i = i
 
-       def list_ios(self, ins, outs, inouts):
-               r = set()
-               if inouts:
-                       r.update(list_signals(self.target))
-               if ins:
-                       r.update(list_signals(self.o))
-                       r.update(list_signals(self.oe))
-               if outs:
-                       r.update(list_signals(self.i))
-               return r
+       def iter_expressions(self):
+               for attr, target_context in [
+                 ("target", SPECIAL_INOUT),
+                 ("o", SPECIAL_INPUT),
+                 ("oe", SPECIAL_INPUT),
+                 ("i", SPECIAL_OUTPUT)]:
+                       yield self, attr, target_context
 
        @staticmethod
        def emit_verilog(tristate, ns, clock_domains):
@@ -79,40 +94,19 @@ class Instance(Special):
                        self.name = name
                        self.value = value
        
-       class _CR:
-               def __init__(self, name_inst, domain="sys", invert=False):
-                       self.name_inst = name_inst
-                       self.domain = domain
-                       self.invert = invert
-       class ClockPort(_CR):
-               pass
-       class ResetPort(_CR):
-               pass
-       
        def get_io(self, name):
                for item in self.items:
                        if isinstance(item, Instance._IO) and item.name == name:
                                return item.expr
 
-       def rename_clock_domain(self, old, new):
-               for cr in filter(lambda x: isinstance(x, Instance._CR), self.items):
-                       if cr.domain == old:
-                               cr.domain = new
-
-       def get_clock_domains(self):
-               return set(cr.domain 
-                       for cr in filter(lambda x: isinstance(x, Instance._CR), self.items))
-
-       def list_ios(self, ins, outs, inouts):
-               subsets = [list_signals(item.expr) for item in filter(lambda x:
-                       (ins and isinstance(x, Instance.Input))
-                       or (outs and isinstance(x, Instance.Output))
-                       or (inouts and isinstance(x, Instance.InOut)),
-                       self.items)]
-               if subsets:
-                       return set.union(*subsets)
-               else:
-                       return set()
+       def iter_expressions(self):
+               for item in self.items:
+                       if isinstance(item, Instance.Input):
+                               yield item, "expr", SPECIAL_INPUT
+                       elif isinstance(item, Instance.Output):
+                               yield item, "expr", SPECIAL_OUTPUT
+                       elif isinstance(item, Instance.InOut):
+                               yield item, "expr", SPECIAL_INOUT
 
        @staticmethod
        def emit_verilog(instance, ns, clock_domains):
@@ -144,20 +138,10 @@ class Instance(Special):
                        if isinstance(p, Instance._IO):
                                name_inst = p.name
                                name_design = verilog_printexpr(ns, p.expr)[0]
-                       elif isinstance(p, Instance.ClockPort):
-                               name_inst = p.name_inst
-                               name_design = ns.get_name(clock_domains[p.domain].clk)
-                               if p.invert:
-                                       name_design = "~" + name_design
-                       elif isinstance(p, Instance.ResetPort):
-                               name_inst = p.name_inst
-                               name_design = ns.get_name(clock_domains[p.domain].rst)
-                       else:
-                               continue
-                       if not firstp:
-                               r += ",\n"
-                       firstp = False
-                       r += "\t." + name_inst + "(" + name_design + ")"
+                               if not firstp:
+                                       r += ",\n"
+                               firstp = False
+                               r += "\t." + name_inst + "(" + name_design + ")"
                if not firstp:
                        r += "\n"
                r += ");\n\n"
@@ -214,27 +198,26 @@ class Memory(Special):
                self.ports.append(mp)
                return mp
 
+       def iter_expressions(self):
+               for p in self.ports:
+                       for attr, target_context in [
+                         ("adr", SPECIAL_INPUT),
+                         ("we", SPECIAL_INPUT),
+                         ("dat_w", SPECIAL_INPUT),
+                         ("re", SPECIAL_INPUT),
+                         ("dat_r", SPECIAL_OUTPUT)]:
+                               yield p, attr, target_context
+
        def rename_clock_domain(self, old, new):
+               # port expressions are always signals - no need to call Special.rename_clock_domain
                for port in self.ports:
                        if port.clock_domain == old:
                                port.clock_domain = new
 
-       def get_clock_domains(self):
+       def list_clock_domains(self):
+               # port expressions are always signals - no need to call Special.list_clock_domains
                return set(port.clock_domain for port in self.ports)
 
-       def list_ios(self, ins, outs, inouts):
-               s = set()
-               def add(*sigs):
-                       for sig in sigs:
-                               if sig is not None:
-                                       s.add(sig)
-               for p in self.ports:
-                       if ins:
-                               add(p.adr, p.we, p.dat_w, p.re)
-                       if outs:
-                               add(p.dat_r)
-               return s
-
        @staticmethod
        def emit_verilog(memory, ns, clock_domains):
                r = ""
@@ -315,9 +298,6 @@ class SynthesisDirective(Special):
                self.template = template
                self.signals = signals
 
-       def list_ios(self, ins, outs, inouts):
-               return set()
-
        @staticmethod
        def emit_verilog(directive, ns, clock_domains):
                name_dict = dict((k, ns.get_name(sig)) for k, sig in directive.signals.items())
index be3c35514a06906cf3513fb6f4db9f01c8f71341..08747359399f4eb61f67fc4d6c2faf0575e808ea 100644 (file)
@@ -170,6 +170,16 @@ class Signal(Value):
        def __repr__(self):
                return "<Signal " + (self.backtrace[-1][0] or "anonymous") + " at " + hex(id(self)) + ">"
 
+class ClockSignal(Value):
+       def __init__(self, cd="sys"):
+               Value.__init__(self)
+               self.cd = cd
+       
+class ResetSignal(Value):
+       def __init__(self, cd="sys"):
+               Value.__init__(self)
+               self.cd = cd
+
 # statements
 
 class _Assign:
@@ -260,6 +270,8 @@ class _ClockDomainList(list):
                else:
                        return list.__getitem__(self, key)
 
+(SPECIAL_INPUT, SPECIAL_OUTPUT, SPECIAL_INOUT) = range(3)
+
 class Fragment:
        def __init__(self, comb=None, sync=None, specials=None, clock_domains=None, sim=None):
                if comb is None: comb = []
@@ -287,15 +299,3 @@ class Fragment:
                        self.specials | other.specials,
                        self.clock_domains + other.clock_domains,
                        self.sim + other.sim)
-       
-       def rename_clock_domain(self, old, new):
-               self.sync[new] = self.sync[old]
-               del self.sync[old]
-               for special in self.specials:
-                       special.rename_clock_domain(old, new)
-               try:
-                       cd = self.clock_domains[old]
-               except KeyError:
-                       pass
-               else:
-                       cd.rename(new)
index 1c93ef33e3c6fc4cab762627574d33aeec498828..80dda17134d81461586683bf337d3cc2ed2e2835 100644 (file)
@@ -4,6 +4,11 @@ from migen.fhdl.structure import *
 from migen.fhdl.structure import _Operator, _Slice, _Assign, _ArrayProxy
 from migen.fhdl.visit import NodeVisitor, NodeTransformer
 
+def bitreverse(s):
+       length, signed = value_bits_sign(s)
+       l = [s[i] for i in reversed(range(length))]
+       return Cat(*l)
+
 def flat_iteration(l):
        for element in l:
                if isinstance(element, collections.Iterable):
@@ -64,10 +69,30 @@ def list_special_ios(f, ins, outs, inouts):
                r |= special.list_ios(ins, outs, inouts)
        return r
 
+class _ClockDomainLister(NodeVisitor):
+       def __init__(self):
+               self.clock_domains = set()
+
+       def visit_ClockSignal(self, node):
+               self.clock_domains.add(node.cd)
+
+       def visit_ResetSignal(self, node):
+               self.clock_domains.add(node.cd)
+
+       def visit_clock_domains(self, node):
+               for clockname, statements in node.items():
+                       self.clock_domains.add(clockname)
+                       self.visit(statements)
+
+def list_clock_domains_expr(f):
+       cdl = _ClockDomainLister()
+       cdl.visit(f)
+       return cdl.clock_domains
+
 def list_clock_domains(f):
-       r = set(f.sync.keys())
+       r = list_clock_domains_expr(f)
        for special in f.specials:
-               r |= special.get_clock_domains()
+               r |= special.list_clock_domains()
        for cd in f.clock_domains:
                r.add(cd.name)
        return r
@@ -99,6 +124,8 @@ def value_bits_sign(v):
                return bits_for(v), v < 0
        elif isinstance(v, Signal):
                return v.nbits, v.signed
+       elif isinstance(v, (ClockSignal, ResetSignal)):
+               return 1, False
        elif isinstance(v, _Operator):
                obs = list(map(value_bits_sign, v.operands))
                if v.op == "+" or v.op == "-":
@@ -168,11 +195,14 @@ def value_bits_sign(v):
        else:
                raise TypeError
 
-class _ArrayLowerer(NodeTransformer):
-       def __init__(self):
+# Basics are FHDL structure elements that back-ends are not required to support
+# but can be expressed in terms of other elements (lowered) before conversion.
+class _BasicLowerer(NodeTransformer):
+       def __init__(self, clock_domains):
                self.comb = []
                self.target_context = False
                self.extra_stmts = []
+               self.clock_domains = clock_domains
 
        def visit_Assign(self, node):
                old_target_context, old_extra_stmts = self.target_context, self.extra_stmts
@@ -203,13 +233,58 @@ class _ArrayLowerer(NodeTransformer):
                        self.comb.append(Case(self.visit(node.key), cases).makedefault())
                return array_muxed
 
-def lower_arrays(f):
-       al = _ArrayLowerer()
-       tf = al.visit(f)
-       tf.comb += al.comb
-       return tf
+       def visit_ClockSignal(self, node):
+               return self.clock_domains[node.cd].clk
 
-def bitreverse(s):
-       length, signed = value_bits_sign(s)
-       l = [s[i] for i in reversed(range(length))]
-       return Cat(*l)
+       def visit_ResetSignal(self, node):
+               return self.clock_domains[node.cd].rst
+
+def lower_basics(f):
+       bl = _BasicLowerer(f.clock_domains)
+       f = bl.visit(f)
+       f.comb += bl.comb
+
+       for special in f.specials:
+               for obj, attr, direction in special.iter_expressions():
+                       if direction != SPECIAL_INOUT:
+                               # inouts are only supported by Migen when connected directly to top-level
+                               # in this case, they are Signal and never need lowering
+                               bl.comb = []
+                               bl.target_context = direction != SPECIAL_INPUT
+                               bl.extra_stmts = []
+                               expr = getattr(obj, attr)
+                               expr = bl.visit(expr)
+                               setattr(obj, attr, expr)
+                               f.comb += bl.comb + bl.extra_stmts
+
+       return f
+
+class _ClockDomainRenamer(NodeVisitor):
+       def __init__(self, old, new):
+               self.old = old
+               self.new = new
+
+       def visit_ClockSignal(self, node):
+               if node.cd == self.old:
+                       node.cd = self.new
+
+       def visit_ResetSignal(self, node):
+               if node.cd == self.old:
+                       node.cd = self.new
+
+def rename_clock_domain_expr(f, old, new):
+       cdr = _ClockDomainRenamer(old, new)
+       cdr.visit(f)
+
+def rename_clock_domain(f, old, new):
+       rename_clock_domain_expr(f, old, new)
+       f.sync[new] = f.sync[old]
+       del f.sync[old]
+       for special in f.specials:
+               special.rename_clock_domain(old, new)
+       try:
+               cd = f.clock_domains[old]
+       except KeyError:
+               pass
+       else:
+               cd.rename(new)
index 7d2f5c55a2cc014c5c81a83986813e2823e8717a..0c873ca08327e50263f613ecd8f6cdfa55710f87 100644 (file)
@@ -203,7 +203,7 @@ def _printcomb(f, ns, display_run):
 def _insert_resets(f):
        newsync = dict()
        for k, v in f.sync.items():
-               newsync[k] = insert_reset(f.clock_domains[k].rst, v)
+               newsync[k] = insert_reset(ResetSignal(k), v)
        f.sync = newsync
 
 def _printsync(f, ns):
@@ -263,10 +263,7 @@ def convert(f, ios=None, name="top",
                f = f.get_fragment()
        if ios is None:
                ios = set()
-               
-       f = lower_arrays(f) # this also copies f
-       fs, lowered_specials = _lower_specials(special_overrides, f.specials)
-       f += fs
+
        for cd_name in list_clock_domains(f):
                try:
                        f.clock_domains[cd_name]
@@ -274,7 +271,11 @@ def convert(f, ios=None, name="top",
                        cd = ClockDomain(cd_name)
                        f.clock_domains.append(cd)
                        ios |= {cd.clk, cd.rst}
+       
        _insert_resets(f)
+       f = lower_basics(f)
+       fs, lowered_specials = _lower_specials(special_overrides, f.specials)
+       f += lower_basics(fs)
 
        ns = build_namespace(list_signals(f) \
                | list_special_ios(f, True, True, True) \
index 059daeef0a405a80e897d8b463d7ed5d7b3dae4c..d8b4ff861c408ea263692b1024a52dc4df28a79a 100644 (file)
@@ -9,6 +9,10 @@ class NodeVisitor:
                        self.visit_constant(node)
                elif isinstance(node, Signal):
                        self.visit_Signal(node)
+               elif isinstance(node, ClockSignal):
+                       self.visit_ClockSignal(node)
+               elif isinstance(node, ResetSignal):
+                       self.visit_ResetSignal(node)
                elif isinstance(node, _Operator):
                        self.visit_Operator(node)
                elif isinstance(node, _Slice):
@@ -39,6 +43,12 @@ class NodeVisitor:
        
        def visit_Signal(self, node):
                pass
+
+       def visit_ClockSignal(self, node):
+               pass
+
+       def visit_ResetSignal(self, node):
+               pass
        
        def visit_Operator(self, node):
                for o in node.operands:
@@ -89,7 +99,7 @@ class NodeVisitor:
                pass
 
 # Default methods always copy the node, except for:
-# - Signals
+# - Signals, ClockSignals and ResetSignals
 # - Unknown objects
 # - All fragment fields except comb and sync
 # In those cases, the original node is returned unchanged.
@@ -99,6 +109,10 @@ class NodeTransformer:
                        return self.visit_constant(node)
                elif isinstance(node, Signal):
                        return self.visit_Signal(node)
+               elif isinstance(node, ClockSignal):
+                       return self.visit_ClockSignal(node)
+               elif isinstance(node, ResetSignal):
+                       return self.visit_ResetSignal(node)
                elif isinstance(node, _Operator):
                        return self.visit_Operator(node)
                elif isinstance(node, _Slice):
@@ -132,6 +146,12 @@ class NodeTransformer:
        def visit_Signal(self, node):
                return node
        
+       def visit_ClockSignal(self, node):
+               return node
+
+       def visit_ResetSignal(self, node):
+               return node
+
        def visit_Operator(self, node):
                return _Operator(node.op, [self.visit(o) for o in node.operands])
        
index eac3341b79ffb2f3ceb148faad3e71a0efb566dc..a7c2dd85331754c6e44f5e6438f33b9c13fcec4e 100644 (file)
@@ -30,19 +30,18 @@ class MultiReg(Special):
                self.odomain = odomain
                self.n = n
 
+       def iter_expressions(self):
+               yield self, "i", SPECIAL_INPUT
+               yield self, "o", SPECIAL_OUTPUT
+
        def rename_clock_domain(self, old, new):
+               Special.rename_clock_domain(self, old, new)
                if self.odomain == old:
                        self.odomain = new
 
-       def get_clock_domains(self):
-               return {self.odomain}
-
-       def list_ios(self, ins, outs, inouts):
-               r = set()
-               if ins:
-                       r.update(list_signals(self.i))
-               if outs:
-                       r.update(list_signals(self.o))
+       def list_clock_domains(self):
+               r = Special.list_clock_domains(self)
+               r.add(self.odomain)
                return r
 
        @staticmethod