# IEEE Floating Point Conversion
 # Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
 
-from nmigen import Module, Signal, Cat
+from nmigen import Module, Signal, Cat, Mux
 from nmigen.cli import main, verilog
 
 from nmutil.pipemodbase import PipeModBase
 from ieee754.fpcommon.postcalc import FPPostCalcData
 from ieee754.fpcommon.msbhigh import FPMSBHigh
 
-from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
+from ieee754.fpcommon.fpbase import FPNumBaseRecord
 
 
 class FPCVTIntToFloatMod(PipeModBase):
               self.out_pspec.width)
         print("a1", self.in_pspec.width)
         z1 = self.o.z
+        a = self.i.a
         print("z1", z1.width, z1.rmw, z1.e_width, z1.e_start, z1.e_end)
 
         me = self.in_pspec.width
-        mz = self.o.z.rmw
+        mz = z1.rmw
         ms = mz - me
         print("ms-me", ms, me, mz)
 
         signed = Signal(reset_less=True)
         comb += signed.eq(self.i.ctx.op[0])
 
-        # copy of mantissa (one less bit if signed)
+        # mantissa (one less bit if signed), and sign
         mantissa = Signal(me, reset_less=True)
+        sign = Signal(reset_less=True)
 
         # detect signed/unsigned.  key case: -ve numbers need inversion
         # to +ve because the FP sign says if it's -ve or not.
-        with m.If(signed):
-            comb += z1.s.eq(self.i.a[-1])      # sign in top bit of a
-            with m.If(z1.s):
-                comb += mantissa.eq(-self.i.a) # invert input if sign -ve
-            with m.Else():
-                comb += mantissa.eq(self.i.a)  # leave as-is
-        with m.Else():
-            comb += mantissa.eq(self.i.a)      # unsigned, use full a
-            comb += z1.s.eq(0)
+        comb += sign.eq(Mux(signed, a[-1], 0)) # sign in top bit of a
+        comb += mantissa.eq(Mux(signed,
+                                Mux(sign, -a,  # invert input if sign -ve
+                                           a), # leave as-is
+                                a))            # unsigned, use full a
 
         # set input from full INT
         comb += msb.m_in.eq(Cat(0, 0, 0, mantissa)) # g/r/s + input
             # smaller int to larger FP
             comb += z1.e.eq(msb.e_out)
             comb += z1.m[ms:].eq(msb.m_out[3:])
-        comb += z1.create(z1.s, z1.e, z1.m) # ... here
+        comb += z1.s.eq(sign)
+        comb += z1.create(sign, z1.e, z1.m) # ... here
 
         # note: post-normalisation actually appears to be capable of
         # detecting overflow to infinity (FPPackMod).  so it's ok to
             comb += self.o.of.sticky.eq(msb.m_out[:1].bool())
             comb += self.o.of.m0.eq(msb.m_out[3])
 
-        # special cases active by default
-        comb += self.o.out_do_z.eq(1)
+        a_nonzero = Signal(reset_less=True)
+        comb += a_nonzero.eq(~a.bool())
+
+        # prepare zero
+        z_zero = FPNumBaseRecord(z1.width, False, name="z_zero")
+        comb += z_zero.zero(0)
+
+        # special cases?
+        comb += self.o.out_do_z.eq(a_nonzero)
 
         # detect zero
-        with m.If(~self.i.a.bool()):
-            comb += self.o.z.zero(0)
-        with m.Else():
-            comb += self.o.out_do_z.eq(0) # activate normalisation
+        comb += self.o.oz.eq(Mux(a_nonzero, z1.v, z_zero.v))
 
         # copy the context (muxid, operator)
-        comb += self.o.oz.eq(self.o.z.v)
         comb += self.o.ctx.eq(self.i.ctx)
 
         return m
 
     # should be fine.
     dut = FPCVTIntMuxInOut(16, 32, 4, op_wid=1)
     runfp(dut, 16, "test_fcvt_int_pipe_i16_f32", to_int16, fcvt_i16_f32, True,
-          n_vals=100, opcode=0x1)
+          n_vals=20, opcode=0x1)
 
 def test_int_pipe_i32_f64():
     dut = FPCVTIntMuxInOut(32, 64, 4, op_wid=1)
     runfp(dut, 32, "test_fcvt_int_pipe_i32_f64", to_int32, fcvt_i32_f64, True,
-          n_vals=100, opcode=0x1)
+          n_vals=20, opcode=0x1)
 
 def test_int_pipe_i32_f32():
     dut = FPCVTIntMuxInOut(32, 32, 4, op_wid=1)
     runfp(dut, 32, "test_fcvt_int_pipe_i32_f32", to_int32, fcvt_i32_f32, True,
-          n_vals=100, opcode=0x1)
+          n_vals=20, opcode=0x1)
 
 ######################
 # unsigned int to fp
     # should be fine.
     dut = FPCVTIntMuxInOut(16, 32, 4, op_wid=1)
     runfp(dut, 16, "test_fcvt_int_pipe_ui16_f32", to_uint16, fcvt_32, True,
-          n_vals=100)
+          n_vals=20)
 
 def test_int_pipe_ui16_f64():
     dut = FPCVTIntMuxInOut(16, 64, 4, op_wid=1)
     runfp(dut, 16, "test_fcvt_int_pipe_ui16_f64", to_uint16, fcvt_64, True,
-          n_vals=100)
+          n_vals=20)
 
 def test_int_pipe_ui32_f32():
     dut = FPCVTIntMuxInOut(32, 32, 4, op_wid=1)
     runfp(dut, 32, "test_fcvt_int_pipe_ui32_32", to_uint32, fcvt_32, True,
-          n_vals=100)
+          n_vals=20)
 
 def test_int_pipe_ui32_f64():
     dut = FPCVTIntMuxInOut(32, 64, 4, op_wid=1)
     runfp(dut, 32, "test_fcvt_int_pipe_ui32_64", to_uint32, fcvt_64, True,
-          n_vals=100)
+          n_vals=20)
 
 def test_int_pipe_ui64_f32():
     # ok, doing 33 bits here because it's pretty pointless (not entirely)
     # converted to Inf
     dut = FPCVTIntMuxInOut(64, 32, 4, op_wid=1)
     runfp(dut, 33, "test_fcvt_int_pipe_ui64_32", to_uint64, fcvt_64_to_32, True,
-          n_vals=100)
+          n_vals=20)
 
 def test_int_pipe_ui64_f16():
     # ok, doing 17 bits here because it's pretty pointless (not entirely)
     # converted to Inf
     dut = FPCVTIntMuxInOut(64, 16, 4, op_wid=1)
     runfp(dut, 17, "test_fcvt_int_pipe_ui64_16", to_uint64, fcvt_16, True,
-          n_vals=100)
+          n_vals=20)
 
 def test_int_pipe_ui32_f16():
     # ok, doing 17 bits here because it's pretty pointless (not entirely)
     # converted to Inf
     dut = FPCVTIntMuxInOut(32, 16, 4, op_wid=1)
     runfp(dut, 17, "test_fcvt_int_pipe_ui32_16", to_uint32, fcvt_16, True,
-          n_vals=100)
+          n_vals=20)
 
 if __name__ == '__main__':
     for i in range(200):
-        test_int_pipe_ui32_f32()
-        test_int_pipe_i32_f32()
-        continue
         test_int_pipe_i16_f32()
         test_int_pipe_i32_f64()
-        continue
+        test_int_pipe_ui32_f32()
+        test_int_pipe_i32_f32()
         test_int_pipe_ui16_f32()
         test_int_pipe_ui64_f32()
         test_int_pipe_ui32_f16()