remove usage of m.If() in fpmax and replace with Mux
authorMichael Nolan <mtnolan2640@gmail.com>
Tue, 28 Jan 2020 20:58:12 +0000 (15:58 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Tue, 28 Jan 2020 21:05:10 +0000 (16:05 -0500)
src/ieee754/fpmax/fpmax.py

index 5c91976d65a526177eab4a7e9cbca5959be732f8..7d3a0d065016b6375421af4273e72f94efc7eb27 100644 (file)
@@ -3,7 +3,7 @@
 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
 
 
-from nmigen import Module, Signal, Cat, Mux
+from nmigen import Module, Signal, Mux
 
 from nmutil.pipemodbase import PipeModBase
 from ieee754.fpcommon.basedata import FPBaseData
@@ -46,27 +46,49 @@ class FPMAXPipeMod(PipeModBase):
         m.d.comb += [a1.v.eq(self.i.a),
                      b1.v.eq(self.i.b)]
 
+        no_nans = Signal(width)
+        some_nans = Signal(width)
+
+        # Handle NaNs
         has_nan = Signal()
         comb += has_nan.eq(a1.is_nan | b1.is_nan)
         both_nan = Signal()
         comb += both_nan.eq(a1.is_nan & b1.is_nan)
-        with m.If(has_nan):
-            with m.If(both_nan):
-                comb += z1.eq(a1.fp.nan2(0))
-            with m.Else():
-                comb += z1.eq(Mux(a1.is_nan, self.i.b, self.i.a))
-        with m.Else():
-            with m.If(a1.s != b1.s):
-                
-                comb += z1.eq(Mux(a1.s ^ opcode[0], self.i.b, self.i.a))
-            with m.Else():
-                gt = Signal()
-                sign = Signal()
-                comb += sign.eq(a1.s)
-                comb += gt.eq(a1.v > b1.v)
-                comb += z1.eq(Mux(gt ^ sign ^ opcode[0],
+
+        # if(both_nan):
+        #     some_nans = NaN - created from scratch
+        # else:
+        #     some_nans = Mux(a1.is_nan, b, a)
+        comb += some_nans.eq(Mux(both_nan,
+                                 a1.fp.nan2(0),
+                                 Mux(a1.is_nan, self.i.b, self.i.a)))
+
+        # if sign(a) != sign(b):
+        #    no_nans = Mux(a1.s ^ opcode[0], b, a)
+        signs_different = Signal()
+        comb += signs_different.eq(a1.s != b1.s)
+
+        signs_different_value = Signal(width)
+        comb += signs_different_value.eq(Mux(a1.s ^ opcode[0],
+                                             self.i.b,
+                                             self.i.a))
+
+        # else:
+        #    if a.v > b.v:
+        #        no_nans = Mux(opcode[0], b, a)
+        #    else:
+        #        no_nans = Mux(opcode[0], a, b)
+        gt = Signal()
+        sign = Signal()
+        signs_same = Signal(width)
+        comb += sign.eq(a1.s)
+        comb += gt.eq(a1.v > b1.v)
+        comb += signs_same.eq(Mux(gt ^ sign ^ opcode[0],
                                   self.i.a, self.i.b))
-                              
+        comb += no_nans.eq(Mux(signs_different, signs_different_value,
+                               signs_same))
+
+        comb += z1.eq(Mux(has_nan, some_nans, no_nans))
 
         # copy the context (muxid, operator)
         comb += self.o.ctx.eq(self.i.ctx)