correct FPRSQRT specialcases
[ieee754fpu.git] / src / ieee754 / fpdiv / specialcases.py
index 5bbe5c40bd5db88fbd2644f3cfbd60be5cca1d68..75721de27ee9508e06e8b57ea30ef82d3c1673ba 100644 (file)
@@ -1,4 +1,4 @@
-# IEEE Floating Point Multiplier 
+# IEEE Floating Point Multiplier
 
 from nmigen import Module, Signal, Cat, Const, Elaboratable
 from nmigen.cli import main, verilog
@@ -10,6 +10,7 @@ from nmutil.singlepipe import SimpleHandshake, StageChain
 from ieee754.fpcommon.fpbase import FPState, FPID
 from ieee754.fpcommon.getop import FPADDBaseData
 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
+from ieee754.fpmul.align import FPAlignModSingle
 
 
 class FPDIVSpecialCasesMod(Elaboratable):
@@ -44,15 +45,15 @@ class FPDIVSpecialCasesMod(Elaboratable):
         #m.submodules.sc_out_z = self.o.z
 
         # decode: XXX really should move to separate stage
-        a1 = FPNumBaseRecord(self.pspec.width, False)
-        b1 = FPNumBaseRecord(self.pspec.width, False)
+        a1 = FPNumBaseRecord(self.pspec.width, False, name="a1")
+        b1 = FPNumBaseRecord(self.pspec.width, False, name="b1")
         m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
         m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
         m.d.comb += [a1.v.eq(self.i.a),
                      b1.v.eq(self.i.b),
                      self.o.a.eq(a1),
                      self.o.b.eq(b1)
-                    ]
+                     ]
 
         sabx = Signal(reset_less=True)   # sign a xor b (sabx, get it?)
         m.d.comb += sabx.eq(a1.s ^ b1.s)
@@ -63,42 +64,96 @@ class FPDIVSpecialCasesMod(Elaboratable):
         abinf = Signal(reset_less=True)
         m.d.comb += abinf.eq(a1.is_inf & b1.is_inf)
 
-        # if a is NaN or b is NaN return NaN
-        with m.If(abnan):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.nan(0)
-
-        # if a is inf and b is Inf return NaN
-        with m.Elif(abinf):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.nan(0)
-
-        # if a is inf return inf
-        with m.Elif(a1.is_inf):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.inf(sabx)
-
-        # if b is inf return zero
-        with m.Elif(b1.is_inf):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.zero(sabx)
-
-        # if a is zero return zero (or NaN if b is zero)
-        with m.Elif(a1.is_zero):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.zero(sabx)
-            # b is zero return NaN
-            with m.If(b1.is_zero):
+        with m.If(self.i.ctx.op == 0):  # DIV
+            # if a is NaN or b is NaN return NaN
+            with m.If(abnan):
+                m.d.comb += self.o.out_do_z.eq(1)
                 m.d.comb += self.o.z.nan(0)
 
-        # if b is zero return Inf
-        with m.Elif(b1.is_zero):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.inf(sabx)
+            # if a is inf and b is Inf return NaN
+            with m.Elif(abinf):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.nan(0)
 
-        # Denormalised Number checks next, so pass a/b data through
-        with m.Else():
-            m.d.comb += self.o.out_do_z.eq(0)
+            # if a is inf return inf
+            with m.Elif(a1.is_inf):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.inf(sabx)
+
+            # if b is inf return zero
+            with m.Elif(b1.is_inf):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.zero(sabx)
+
+            # if a is zero return zero (or NaN if b is zero)
+            with m.Elif(a1.is_zero):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.zero(sabx)
+                # b is zero return NaN
+                with m.If(b1.is_zero):
+                    m.d.comb += self.o.z.nan(0)
+
+            # if b is zero return Inf
+            with m.Elif(b1.is_zero):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.inf(sabx)
+
+            # Denormalised Number checks next, so pass a/b data through
+            with m.Else():
+                m.d.comb += self.o.out_do_z.eq(0)
+
+        with m.If(self.i.ctx.op == 1):  # SQRT
+
+            # if a is zero return zero
+            with m.If(a1.is_zero):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.zero(a1.s)
+
+            # -ve number is NaN
+            with m.Elif(a1.s):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.nan(0)
+
+            # if a is inf return inf
+            with m.Elif(a1.is_inf):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.inf(sabx)
+
+            # if a is NaN return NaN
+            with m.Elif(a1.is_nan):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.nan(0)
+
+            # Denormalised Number checks next, so pass a/b data through
+            with m.Else():
+                m.d.comb += self.o.out_do_z.eq(0)
+
+        with m.If(self.i.ctx.op == 2):  # RSQRT
+
+            # if a is NaN return canonical NaN
+            with m.If(a1.is_nan):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.nan(0)
+
+            # if a is +/- zero return +/- INF
+            with m.Elif(a1.is_zero):
+                m.d.comb += self.o.out_do_z.eq(1)
+                # this includes the "weird" case 1/sqrt(-0) == -Inf
+                m.d.comb += self.o.z.inf(a1.s)
+
+            # -ve number is canonical NaN
+            with m.Elif(a1.s):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.nan(0)
+
+            # if a is inf return zero (-ve already excluded, above)
+            with m.Elif(a1.is_inf):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.zero(0)
+
+            # Denormalised Number checks next, so pass a/b data through
+            with m.Else():
+                m.d.comb += self.o.out_do_z.eq(0)
 
         m.d.comb += self.o.oz.eq(self.o.z.v)
         m.d.comb += self.o.ctx.eq(self.i.ctx)
@@ -122,7 +177,7 @@ class FPDIVSpecialCases(FPState):
         """ links module to inputs and outputs
         """
         self.mod.setup(m, i, self.out_do_z)
-        m.d.sync += self.out_z.v.eq(self.mod.out_z.v) # only take the output
+        m.d.sync += self.out_z.v.eq(self.mod.out_z.v)  # only take the output
         m.d.sync += self.out_z.mid.eq(self.mod.o.mid)  # (and mid)
 
     def action(self, m):
@@ -140,28 +195,29 @@ class FPDIVSpecialCasesDeNorm(FPState, SimpleHandshake):
     def __init__(self, pspec):
         FPState.__init__(self, "special_cases")
         self.pspec = pspec
-        SimpleHandshake.__init__(self, self) # pipe is its own stage
+        SimpleHandshake.__init__(self, self)  # pipe is its own stage
         self.out = self.ospec()
 
     def ispec(self):
-        return FPADDBaseData(self.pspec) # SpecialCases ispec
+        return FPADDBaseData(self.pspec)  # SpecialCases ispec
 
     def ospec(self):
-        return FPSCData(self.pspec, False) # DeNorm ospec
+        return FPSCData(self.pspec, False)  # Align ospec
 
     def setup(self, m, i):
         """ links module to inputs and outputs
         """
         smod = FPDIVSpecialCasesMod(self.pspec)
         dmod = FPAddDeNormMod(self.pspec, False)
+        amod = FPAlignModSingle(self.pspec, False)
 
-        chain = StageChain([smod, dmod])
+        chain = StageChain([smod, dmod, amod])
         chain.setup(m, i)
 
         # only needed for break-out (early-out)
         # self.out_do_z = smod.o.out_do_z
 
-        self.o = dmod.o
+        self.o = amod.o
 
     def process(self, i):
         return self.o
@@ -173,5 +229,3 @@ class FPDIVSpecialCasesDeNorm(FPState, SimpleHandshake):
         #with m.Else():
             m.d.sync += self.out.eq(self.process(None))
             m.next = "align"
-
-