from soc.regfile.regfiles import FastRegs
 from soc.consts import TT
 from soc.config.state import CoreState
+from soc.regfile.util import spr_to_fast
 
 
 def decode_spr_num(spr):
 
 
 class SPRMap(Elaboratable):
-    """SPRMap: maps POWER9 SPR numbers to internal enum values
+    """SPRMap: maps POWER9 SPR numbers to internal enum values, fast and slow
     """
 
     def __init__(self):
         self.spr_i = Signal(10, reset_less=True)
-        self.spr_o = Signal(SPR, reset_less=True)
+        self.spr_o = Data(SPR, name="spr_o")
+        self.fast_o = Data(3, name="fast_o")
 
     def elaborate(self, platform):
         m = Module()
         with m.Switch(self.spr_i):
             for i, x in enumerate(SPR):
                 with m.Case(x.value):
-                    m.d.comb += self.spr_o.eq(i)
+                    m.d.comb += self.spr_o.data.eq(i)
+                    m.d.comb += self.spr_o.ok.eq(1)
+            for x, v in spr_to_fast.items():
+                with m.Case(x.value):
+                    m.d.comb += self.fast_o.data.eq(v)
+                    m.d.comb += self.fast_o.ok.eq(1)
         return m
 
 
             with m.Case(MicrOp.OP_MFSPR):
                 spr = Signal(10, reset_less=True)
                 comb += spr.eq(decode_spr_num(self.dec.SPR))  # from XFX
-                with m.Switch(spr):
-                    # fast SPRs
-                    with m.Case(SPR.CTR.value):
-                        comb += self.fast_out.data.eq(FastRegs.CTR)
-                        comb += self.fast_out.ok.eq(1)
-                    with m.Case(SPR.LR.value):
-                        comb += self.fast_out.data.eq(FastRegs.LR)
-                        comb += self.fast_out.ok.eq(1)
-                    with m.Case(SPR.TAR.value):
-                        comb += self.fast_out.data.eq(FastRegs.TAR)
-                        comb += self.fast_out.ok.eq(1)
-                    with m.Case(SPR.SRR0.value):
-                        comb += self.fast_out.data.eq(FastRegs.SRR0)
-                        comb += self.fast_out.ok.eq(1)
-                    with m.Case(SPR.SRR1.value):
-                        comb += self.fast_out.data.eq(FastRegs.SRR1)
-                        comb += self.fast_out.ok.eq(1)
-                    with m.Case(SPR.XER.value):
-                        comb += self.fast_out.data.eq(FastRegs.XER)
-                        comb += self.fast_out.ok.eq(1)
-                    with m.Case(SPR.DEC.value):
-                        comb += self.fast_out.data.eq(FastRegs.DEC)
-                        comb += self.fast_out.ok.eq(1)
-                    with m.Case(SPR.TB.value):
-                        comb += self.fast_out.data.eq(FastRegs.TB)
-                        comb += self.fast_out.ok.eq(1)
-                    # : map to internal SPR numbers
-                    # XXX TODO: dec and tb not to go through mapping.
-                    with m.Default():
-                        comb += sprmap.spr_i.eq(spr)
-                        comb += self.spr_out.data.eq(sprmap.spr_o)
-                        comb += self.spr_out.ok.eq(1)
+                comb += sprmap.spr_i.eq(spr)
+                comb += self.spr_out.eq(sprmap.spr_o)
+                comb += self.fast_out.eq(sprmap.fast_o)
 
         return m
 
             with m.Case(OutSel.SPR):
                 spr = Signal(10, reset_less=True)
                 comb += spr.eq(decode_spr_num(self.dec.SPR))  # from XFX
-                # TODO MTSPR 1st spr (fast)
+                # MFSPR move to SPRs - needs mapping
                 with m.If(op.internal_op == MicrOp.OP_MTSPR):
-                    with m.Switch(spr):
-                        # fast SPRs
-                        with m.Case(SPR.CTR.value):
-                            comb += self.fast_out.data.eq(FastRegs.CTR)
-                            comb += self.fast_out.ok.eq(1)
-                        with m.Case(SPR.LR.value):
-                            comb += self.fast_out.data.eq(FastRegs.LR)
-                            comb += self.fast_out.ok.eq(1)
-                        with m.Case(SPR.TAR.value):
-                            comb += self.fast_out.data.eq(FastRegs.TAR)
-                            comb += self.fast_out.ok.eq(1)
-                        with m.Case(SPR.SRR0.value):
-                            comb += self.fast_out.data.eq(FastRegs.SRR0)
-                            comb += self.fast_out.ok.eq(1)
-                        with m.Case(SPR.SRR1.value):
-                            comb += self.fast_out.data.eq(FastRegs.SRR1)
-                            comb += self.fast_out.ok.eq(1)
-                        with m.Case(SPR.XER.value):
-                            comb += self.fast_out.data.eq(FastRegs.XER)
-                            comb += self.fast_out.ok.eq(1)
-                        with m.Case(SPR.TB.value):
-                            comb += self.fast_out.data.eq(FastRegs.TB)
-                            comb += self.fast_out.ok.eq(1)
-                        with m.Case(SPR.DEC.value):
-                            comb += self.fast_out.data.eq(FastRegs.DEC)
-                            comb += self.fast_out.ok.eq(1)
-                        # : map to internal SPR numbers
-                        # XXX TODO: dec and tb not to go through mapping.
-                        with m.Default():
-                            comb += sprmap.spr_i.eq(spr)
-                            comb += self.spr_out.data.eq(sprmap.spr_o)
-                            comb += self.spr_out.ok.eq(1)
+                    comb += sprmap.spr_i.eq(spr)
+                    comb += self.spr_out.eq(sprmap.spr_o)
+                    comb += self.fast_out.eq(sprmap.fast_o)
 
         with m.Switch(op.internal_op):
 
 
 from soc.regfile.regfiles import FastRegs
 from soc.decoder.power_enums import SPR, spr_dict
 
+spr_to_fast = { SPR.CTR: FastRegs.CTR,
+                SPR.LR: FastRegs.LR,
+                SPR.TAR: FastRegs.TAR,
+                SPR.SRR0: FastRegs.SRR0,
+                SPR.SRR1: FastRegs.SRR1,
+                SPR.XER: FastRegs.XER,
+                SPR.DEC: FastRegs.DEC,
+                SPR.TB: FastRegs.TB,
+               }
+
+sprstr_to_fast = {}
+fast_to_spr = {}
+for (k, v) in spr_to_fast.items():
+    sprstr_to_fast[k.name] = v
+    fast_to_spr[v] = k
+
 def fast_reg_to_spr(spr_num):
-    if spr_num == FastRegs.CTR:
-        return SPR.CTR.value
-    elif spr_num == FastRegs.LR:
-        return SPR.LR.value
-    elif spr_num == FastRegs.TAR:
-        return SPR.TAR.value
-    elif spr_num == FastRegs.SRR0:
-        return SPR.SRR0.value
-    elif spr_num == FastRegs.SRR1:
-        return SPR.SRR1.value
+    return fast_to_spr[spr_num].value
 
 
 def spr_to_fast_reg(spr_num):
     if not isinstance(spr_num, str):
         spr_num = spr_dict[spr_num].SPR
-    if spr_num == 'CTR':
-        return FastRegs.CTR
-    elif spr_num == 'LR':
-        return FastRegs.LR
-    elif spr_num == 'TAR':
-        return FastRegs.TAR
-    elif spr_num == 'SRR0':
-        return FastRegs.SRR0
-    elif spr_num == 'SRR1':
-        return FastRegs.SRR1
-    return None
+    return sprstr_to_fast[spr_num]