based on Anton Blanchard microwatt decode2.vhdl
 
 """
-from nmigen import Module, Elaboratable, Signal, Mux, Const
+from nmigen import Module, Elaboratable, Signal, Mux, Const, Cat, Repl
 from nmigen.cli import rtlil
 
 from soc.decoder.power_decoder import create_pdecode
         self.imm_out = Data(64, "imm_b")
         self.spr_out = Data(10, "spr_b")
 
+    def exts(self, exts_data, width, fullwidth):
+        exts_data = exts_data[0:width]
+        topbit = exts_data[-1]
+        signbits = Repl(topbit, fullwidth-width)
+        return Cat(exts_data, signbits)
+
+
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
                 comb += self.imm_out.data.eq(self.dec.UI[0:-1])
                 comb += self.imm_out.ok.eq(1)
             with m.Case(In2Sel.CONST_SI): # TODO: sign-extend here?
-                comb += self.imm_out.data.eq(self.dec.SI[0:-1])
+                comb += self.imm_out.data.eq(
+                    self.exts(self.dec.SI[0:-1], 16, 64))
                 comb += self.imm_out.ok.eq(1)
             with m.Case(In2Sel.CONST_UI_HI):
                 comb += self.imm_out.data.eq(self.dec.UI[0:-1]<<16)
                 comb += self.imm_out.ok.eq(1)
             with m.Case(In2Sel.CONST_SI_HI): # TODO: sign-extend here?
                 comb += self.imm_out.data.eq(self.dec.SI[0:-1]<<16)
+                comb += self.imm_out.data.eq(
+                    self.exts(self.dec.SI[0:-1] << 16, 32, 64))
                 comb += self.imm_out.ok.eq(1)
             with m.Case(In2Sel.CONST_LI):
                 comb += self.imm_out.data.eq(self.dec.LI[0:-1]<<2)
 
     def __init__(self, num):
         self.num = num
 
+class Checker:
+    def __init__(self):
+        self.imm = 0
+
+    def get_imm(self, in2_sel):
+        if in2_sel == In2Sel.CONST_UI.value:
+            return self.imm & 0xffff
+        if in2_sel == In2Sel.CONST_UI_HI.value:
+            return (self.imm & 0xffff) << 16
+        if in2_sel == In2Sel.CONST_SI.value:
+            sign_bit = 1 << 15
+            return (self.imm & (sign_bit-1)) - (self.imm & sign_bit)
+        if in2_sel == In2Sel.CONST_SI_HI.value:
+            imm = self.imm << 16
+            sign_bit = 1 << 31
+            return (imm & (sign_bit-1)) - (imm & sign_bit)
+        
 
 class RegRegOp:
     def __init__(self):
             assert(rc == 0)
 
 
-class RegImmOp:
+class RegImmOp(Checker):
     def __init__(self):
+        super().__init__()
         self.ops = {
             "addi": InternalOp.OP_ADD,
             "addis": InternalOp.OP_ADD,
 
         imm = yield pdecode2.e.imm_data.data
         in2_sel = yield pdecode2.dec.op.in2_sel
-        if in2_sel in [In2Sel.CONST_SI_HI.value, In2Sel.CONST_UI_HI.value]:
-            assert(imm == (self.imm << 16))
-        else:
-            assert(imm == self.imm)
+        imm_expected = self.get_imm(in2_sel)
+        msg = "imm: got {:x}, expected {:x}".format(imm, imm_expected)
+        assert imm == imm_expected, msg
 
         rc = yield pdecode2.e.rc.data
         if '.' in self.opcodestr:
             assert(rc == 0)
 
 
-class LdStOp:
+class LdStOp(Checker):
     def __init__(self):
+        super().__init__()
         self.ops = {
             "lwz": InternalOp.OP_LOAD,
             "stw": InternalOp.OP_STORE,
         assert(r2sel == self.r2.num)
 
         imm = yield pdecode2.e.imm_data.data
-        assert(imm == self.imm)
+        in2_sel = yield pdecode2.dec.op.in2_sel
+        assert(imm == self.get_imm(in2_sel))
 
         update = yield pdecode2.e.update
         if "u" in self.opcodestr: