Refactor Case
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Thu, 29 Nov 2012 00:11:15 +0000 (01:11 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Thu, 29 Nov 2012 00:11:15 +0000 (01:11 +0100)
migen/actorlib/structuring.py
migen/bank/csrgen.py
migen/corelogic/fsm.py
migen/corelogic/misc.py
migen/corelogic/roundrobin.py
migen/fhdl/structure.py
migen/fhdl/tools.py
migen/fhdl/verilog.py
migen/fhdl/visit.py
migen/pytholite/reg.py

index 3402e2f08af9276772f66bbfb5431bb2b13e7049..bb0ff0bbb36e8b7a38f8c330f2a77745c138bbfb 100644 (file)
@@ -50,10 +50,10 @@ class Unpack(Actor):
                                )
                        )
                ]
-               cases = [(i if i else Default(),
-                       Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten())))
-                       for i in range(self.n)]
-               comb.append(Case(mux, *cases))
+               cases = {}
+               for i in range(self.n):
+                       cases[i] = [Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten()))]
+               comb.append(Case(mux, cases).makedefault())
                return Fragment(comb, sync)
 
 class Pack(Actor):
@@ -69,9 +69,9 @@ class Pack(Actor):
                
                load_part = Signal()
                strobe_all = Signal()
-               cases = [(i,
-                       Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten()))
-                       for i in range(self.n)]
+               cases = {}
+               for i in range(self.n):
+                       cases[i] = [Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten())]
                comb = [
                        self.busy.eq(strobe_all),
                        self.endpoints["sink"].ack.eq(~strobe_all | self.endpoints["source"].ack),
@@ -83,7 +83,7 @@ class Pack(Actor):
                                strobe_all.eq(0)
                        ),
                        If(load_part,
-                               Case(demux, *cases),
+                               Case(demux, cases),
                                If(demux == (self.n - 1),
                                        demux.eq(0),
                                        strobe_all.eq(1)
index ae942a486fbe0c7b1826f8a54f49fb15ef07d568..9b28904d27e6ad789bce3e64bc3870a60462c2a0 100644 (file)
@@ -21,7 +21,7 @@ class Bank:
                nbits = bits_for(len(desc_exp)-1)
                
                # Bus writes
-               bwcases = []
+               bwcases = {}
                for i, reg in enumerate(desc_exp):
                        if isinstance(reg, RegisterRaw):
                                comb.append(reg.r.eq(self.interface.dat_w[:reg.size]))
@@ -29,14 +29,14 @@ class Bank:
                                        self.interface.we & \
                                        (self.interface.adr[:nbits] == i)))
                        elif isinstance(reg, RegisterFields):
-                               bwra = [i]
+                               bwra = []
                                offset = 0
                                for field in reg.fields:
                                        if field.access_bus == WRITE_ONLY or field.access_bus == READ_WRITE:
                                                bwra.append(field.storage.eq(self.interface.dat_w[offset:offset+field.size]))
                                        offset += field.size
-                               if len(bwra) > 1:
-                                       bwcases.append(bwra)
+                               if bwra:
+                                       bwcases[i] = bwra
                                # commit atomic writes
                                for field in reg.fields:
                                        if isinstance(field, FieldAlias) and field.commit_list:
@@ -45,13 +45,13 @@ class Bank:
                        else:
                                raise TypeError
                if bwcases:
-                       sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], *bwcases)))
+                       sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], bwcases)))
                
                # Bus reads
-               brcases = []
+               brcases = {}
                for i, reg in enumerate(desc_exp):
                        if isinstance(reg, RegisterRaw):
-                               brcases.append([i, self.interface.dat_r.eq(reg.w)])
+                               brcases[i] = [self.interface.dat_r.eq(reg.w)]
                        elif isinstance(reg, RegisterFields):
                                brs = []
                                reg_readable = False
@@ -62,12 +62,12 @@ class Bank:
                                        else:
                                                brs.append(Replicate(0, field.size))
                                if reg_readable:
-                                       brcases.append([i, self.interface.dat_r.eq(Cat(*brs))])
+                                       brcases[i] = [self.interface.dat_r.eq(Cat(*brs))]
                        else:
                                raise TypeError
                if brcases:
                        sync.append(self.interface.dat_r.eq(0))
-                       sync.append(If(sel, Case(self.interface.adr[:nbits], *brcases)))
+                       sync.append(If(sel, Case(self.interface.adr[:nbits], brcases)))
                else:
                        comb.append(self.interface.dat_r.eq(0))
                
index dbd4c02a0c33335a03d2b252b9d210c8e3396650..cddd4b37b3c45611ac7d8547446472325981f2aa 100644 (file)
@@ -33,10 +33,10 @@ class FSM:
                self.actions[state] += statements
        
        def get_fragment(self):
-               cases = [[s] + a for s, a in enumerate(self.actions) if a]
+               cases = dict((s, a) for s, a in enumerate(self.actions) if a)
                comb = [
                        self._next_state.eq(self._state),
-                       Case(self._state, *cases)
+                       Case(self._state, cases)
                ]
                sync = [self._state.eq(self._next_state)]
                return Fragment(comb, sync)
index a4da7dcf2c55078f7af924e25676b746d5536f0a..888941cc2b519fa9e68463ba9cbe5ae206aa0d4e 100644 (file)
@@ -43,15 +43,14 @@ def chooser(signal, shift, output, n=None, reverse=False):
        if n is None:
                n = 2**len(shift)
        w = len(output)
-       cases = []
+       cases = {}
        for i in range(n):
                if reverse:
                        s = n - i - 1
                else:
                        s = i
-               cases.append([i, output.eq(signal[s*w:(s+1)*w])])
-       cases[n-1][0] = Default()
-       return Case(shift, *cases)
+               cases[i] = [output.eq(signal[s*w:(s+1)*w])]
+       return Case(shift, cases).makedefault()
 
 def timeline(trigger, events):
        lastevent = max([e[0] for e in events])
index a0d7c73482372df70ae35297a2fad558dd9ac953..f2221bb51168bc888a0225ab3555eea2f009a594 100644 (file)
@@ -14,7 +14,7 @@ class RoundRobin:
        
        def get_fragment(self):
                if self.n > 1:
-                       cases = []
+                       cases = {}
                        for i in range(self.n):
                                switch = []
                                for j in reversed(range(i+1,i+self.n)):
@@ -30,8 +30,8 @@ class RoundRobin:
                                        case = [If(~self.request[i], *switch)]
                                else:
                                        case = switch
-                               cases.append([i] + case)
-                       statement = Case(self.grant, *cases)
+                               cases[i] = case
+                       statement = Case(self.grant, cases)
                        if self.switch_policy == SP_CE:
                                statement = If(self.ce, statement)
                        return Fragment(sync=[statement])
index 16faebc809a8b0231fd967121c4f25d8e390fefa..906f4ccb9fced33e4c387e3491ffd10078505781 100644 (file)
@@ -189,21 +189,19 @@ def _insert_else(obj, clause):
                o = o.f[0]
        o.f = clause
 
-class Default:
-       pass
-
 class Case:
-       def __init__(self, test, *cases):
+       def __init__(self, test, cases):
                self.test = test
-               self.cases = [(c[0], list(c[1:])) for c in cases if not isinstance(c[0], Default)]
-               self.default = None
-               for c in cases:
-                       if isinstance(c[0], Default):
-                               if self.default is not None:
-                                       raise ValueError
-                               self.default = list(c[1:])
-               if self.default is None:
-                       self.default = []
+               self.cases = cases
+       
+       def makedefault(self, key=None):
+               if key is None:
+                       for choice in self.cases.keys():
+                               if key is None or choice > key:
+                                       key = choice
+               self.cases["default"] = self.cases[key]
+               del self.cases[key]
+               return self
 
 # arrays
 
index b26aa441512e8c0b98c3fd721cb7439b53440eb9..925e278f6ce9b7c5546cd46f35008bf170115ae2 100644 (file)
@@ -151,21 +151,19 @@ class _ArrayLowerer(NodeTransformer):
        def visit_Assign(self, node):
                if isinstance(node.l, _ArrayProxy):
                        k = self.visit(node.l.key)
-                       cases = []
+                       cases = {}
                        for n, choice in enumerate(node.l.choices):
                                assign = self.visit_Assign(_Assign(choice, node.r))
-                               cases.append([n, assign])
-                       cases[-1][0] = Default()
-                       return Case(k, *cases)
+                               cases[n] = [assign]
+                       return Case(k, cases).makedefault()
                else:
                        return super().visit_Assign(node)
        
        def visit_ArrayProxy(self, node):
                array_muxed = Signal(value_bv(node))
-               cases = [[n, _Assign(array_muxed, self.visit(choice))]
-                       for n, choice in enumerate(node.choices)]
-               cases[-1][0] = Default()
-               self.comb.append(Case(self.visit(node.key), *cases))
+               cases = dict((n, _Assign(array_muxed, self.visit(choice)))
+                       for n, choice in enumerate(node.choices))
+               self.comb.append(Case(self.visit(node.key), cases).makedefault())
                return array_muxed
 
 def lower_arrays(f):
index ec878c0cad9ec423d0478c79249bffd2395b3e5f..f9e03ff1426f9907213fa8f1c8ea0f5d57481371 100644 (file)
@@ -92,15 +92,16 @@ def _printnode(ns, at, level, node):
                r += "\t"*level + "end\n"
                return r
        elif isinstance(node, Case):
-               if node.cases or node.default:
+               if node.cases:
                        r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n"
-                       for case in node.cases:
-                               r += "\t"*(level + 1) + _printexpr(ns, case[0]) + ": begin\n"
-                               r += _printnode(ns, at, level + 2, case[1])
+                       css = sorted([(k, v) for (k, v) in node.cases.items() if k != "default"], key=itemgetter(0))
+                       for choice, statements in css:
+                               r += "\t"*(level + 1) + _printexpr(ns, choice) + ": begin\n"
+                               r += _printnode(ns, at, level + 2, statements)
                                r += "\t"*(level + 1) + "end\n"
-                       if node.default:
+                       if "default" in node.cases:
                                r += "\t"*(level + 1) + "default: begin\n"
-                               r += _printnode(ns, at, level + 2, node.default)
+                               r += _printnode(ns, at, level + 2, node.cases["default"])
                                r += "\t"*(level + 1) + "end\n"
                        r += "\t"*level + "endcase\n"
                        return r
index 44abc77de29e72fb7138b7051d0f6c5721fc3ddb..3c39700267d909e06f13441fac06840f97a9bc07 100644 (file)
@@ -65,9 +65,8 @@ class NodeVisitor:
        
        def visit_Case(self, node):
                self.visit(node.test)
-               for v, statements in node.cases:
+               for v, statements in node.cases.items():
                        self.visit(statements)
-               self.visit(node.default)
        
        def visit_Fragment(self, node):
                self.visit(node.comb)
@@ -155,9 +154,8 @@ class NodeTransformer:
                return r
        
        def visit_Case(self, node):
-               r = Case(self.visit(node.test))
-               r.cases = [(v, self.visit(statements)) for v, statements in node.cases]
-               r.default = self.visit(node.default)
+               cases = dict((v, self.visit(statements)) for v, statements in node.cases.items())
+               r = Case(self.visit(node.test), cases)
                return r
        
        def visit_Fragment(self, node):
index 9a4815299bc19928450557f96d1a6d25314a936a..f54323d560754929fcb1bc41a90aefd6a2b5d58b 100644 (file)
@@ -48,6 +48,6 @@ class ImplRegister:
                        raise FinalizeError
                # do nothing when sel == 0
                items = sorted(self.source_encoding.items(), key=itemgetter(1))
-               cases = [(v, self.storage.eq(self.id_to_source[k])) for k, v in items]
-               sync = [Case(self.sel, *cases)]
+               cases = dict((v, self.storage.eq(self.id_to_source[k])) for k, v in items)
+               sync = [Case(self.sel, cases)]
                return Fragment(sync=sync)