move FPModBase and FPModBaseChain to nmutil
[ieee754fpu.git] / src / ieee754 / fpdiv / specialcases.py
index 227eaab4a0ca89fdff007d4042ad67472b3bcaa7..5b63702062cfb6428c5aa0dee9b8b1b3bf50ea45 100644 (file)
@@ -9,30 +9,26 @@ Relevant bugreports:
 * http://bugs.libre-riscv.org/show_bug.cgi?id=44
 """
 
-from nmigen import Module, Signal, Cat, Const, Elaboratable
+from nmigen import Module, Signal
 from nmigen.cli import main, verilog
 from math import log
 
+from nmutil.pipemodbase import FPModBase, FPModBaseChain
 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
-from nmutil.singlepipe import SimpleHandshake, StageChain
-
-from ieee754.fpcommon.fpbase import FPState, FPID
 from ieee754.fpcommon.getop import FPADDBaseData
 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
 from ieee754.fpmul.align import FPAlignModSingle
 from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation as DP
 
 
-class FPDIVSpecialCasesMod(Elaboratable):
+class FPDIVSpecialCasesMod(FPModBase):
     """ special cases: NaNs, infs, zeros, denormalised
         see "Special Operations"
         https://steve.hollasch.net/cgindex/coding/ieeefloat.html
     """
 
     def __init__(self, pspec):
-        self.pspec = pspec
-        self.i = self.ispec()
-        self.o = self.ospec()
+        super().__init__(pspec, "specialcases")
 
     def ispec(self):
         return FPADDBaseData(self.pspec)
@@ -40,15 +36,6 @@ class FPDIVSpecialCasesMod(Elaboratable):
     def ospec(self):
         return FPSCData(self.pspec, False)
 
-    def setup(self, m, i):
-        """ links module to inputs and outputs
-        """
-        m.submodules.specialcases = self
-        m.d.comb += self.i.eq(i)
-
-    def process(self, i):
-        return self.o
-
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
@@ -64,42 +51,41 @@ class FPDIVSpecialCasesMod(Elaboratable):
                      self.o.b.eq(b1)
                      ]
 
+        # temporaries (used below)
         sabx = Signal(reset_less=True)   # sign a xor b (sabx, get it?)
-        comb += sabx.eq(a1.s ^ b1.s)
-
         abnan = Signal(reset_less=True)
-        comb += abnan.eq(a1.is_nan | b1.is_nan)
-
         abinf = Signal(reset_less=True)
+
+        comb += sabx.eq(a1.s ^ b1.s)
+        comb += abnan.eq(a1.is_nan | b1.is_nan)
         comb += abinf.eq(a1.is_inf & b1.is_inf)
 
+        # default (overridden if needed)
+        comb += self.o.out_do_z.eq(1)
+
+        # select one of 3 different sets of specialcases (DIV, SQRT, RSQRT)
         with m.Switch(self.i.ctx.op):
 
             with m.Case(int(DP.UDivRem)): # DIV
 
                 # if a is NaN or b is NaN return NaN
                 with m.If(abnan):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.nan(0)
 
                 # if a is inf and b is Inf return NaN
                 with m.Elif(abinf):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.nan(0)
 
                 # if a is inf return inf
                 with m.Elif(a1.is_inf):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.inf(sabx)
 
                 # if b is inf return zero
                 with m.Elif(b1.is_inf):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.zero(sabx)
 
                 # if a is zero return zero (or NaN if b is zero)
                 with m.Elif(a1.is_zero):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.zero(sabx)
                     # b is zero return NaN
                     with m.If(b1.is_zero):
@@ -107,7 +93,6 @@ class FPDIVSpecialCasesMod(Elaboratable):
 
                 # if b is zero return Inf
                 with m.Elif(b1.is_zero):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.inf(sabx)
 
                 # Denormalised Number checks next, so pass a/b data through
@@ -118,22 +103,18 @@ class FPDIVSpecialCasesMod(Elaboratable):
 
                 # if a is zero return zero
                 with m.If(a1.is_zero):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.zero(a1.s)
 
                 # -ve number is NaN
                 with m.Elif(a1.s):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.nan(0)
 
                 # if a is inf return inf
                 with m.Elif(a1.is_inf):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.inf(sabx)
 
                 # if a is NaN return NaN
                 with m.Elif(a1.is_nan):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.nan(0)
 
                 # Denormalised Number checks next, so pass a/b data through
@@ -144,23 +125,19 @@ class FPDIVSpecialCasesMod(Elaboratable):
 
                 # if a is NaN return canonical NaN
                 with m.If(a1.is_nan):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.nan(0)
 
                 # if a is +/- zero return +/- INF
                 with m.Elif(a1.is_zero):
-                    comb += self.o.out_do_z.eq(1)
                     # this includes the "weird" case 1/sqrt(-0) == -Inf
                     comb += self.o.z.inf(a1.s)
 
                 # -ve number is canonical NaN
                 with m.Elif(a1.s):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.nan(0)
 
                 # if a is inf return zero (-ve already excluded, above)
                 with m.Elif(a1.is_inf):
-                    comb += self.o.out_do_z.eq(1)
                     comb += self.o.z.zero(0)
 
                 # Denormalised Number checks next, so pass a/b data through
@@ -173,71 +150,15 @@ class FPDIVSpecialCasesMod(Elaboratable):
         return m
 
 
-class FPDIVSpecialCases(FPState):
-    """ special cases: NaNs, infs, zeros, denormalised
-        NOTE: some of these are unique to div.  see "Special Operations"
-        https://steve.hollasch.net/cgindex/coding/ieeefloat.html
-    """
-
-    def __init__(self, pspec):
-        FPState.__init__(self, "special_cases")
-        self.mod = FPDIVSpecialCasesMod(pspec)
-        self.out_z = self.mod.ospec()
-        self.out_do_z = Signal(reset_less=True)
-
-    def setup(self, m, i):
-        """ links module to inputs and outputs
-        """
-        self.mod.setup(m, i, self.out_do_z)
-        m.d.sync += self.out_z.v.eq(self.mod.out_z.v)  # only take the output
-        m.d.sync += self.out_z.mid.eq(self.mod.o.mid)  # (and mid)
-
-    def action(self, m):
-        self.idsync(m)
-        with m.If(self.out_do_z):
-            m.next = "put_z"
-        with m.Else():
-            m.next = "denormalise"
-
-
-class FPDIVSpecialCasesDeNorm(FPState, SimpleHandshake):
+class FPDIVSpecialCasesDeNorm(FPModBaseChain):
     """ special cases: NaNs, infs, zeros, denormalised
     """
 
-    def __init__(self, pspec):
-        FPState.__init__(self, "special_cases")
-        self.pspec = pspec
-        SimpleHandshake.__init__(self, self)  # pipe is its own stage
-        self.out = self.ospec()
-
-    def ispec(self):
-        return FPADDBaseData(self.pspec)  # SpecialCases ispec
-
-    def ospec(self):
-        return FPSCData(self.pspec, False)  # Align ospec
-
-    def setup(self, m, i):
+    def get_chain(self):
         """ links module to inputs and outputs
         """
         smod = FPDIVSpecialCasesMod(self.pspec)
         dmod = FPAddDeNormMod(self.pspec, False)
         amod = FPAlignModSingle(self.pspec, False)
 
-        chain = StageChain([smod, dmod, amod])
-        chain.setup(m, i)
-
-        # only needed for break-out (early-out)
-        # self.out_do_z = smod.o.out_do_z
-
-        self.o = amod.o
-
-    def process(self, i):
-        return self.o
-
-    def action(self, m):
-        # for break-out (early-out)
-        #with m.If(self.out_do_z):
-        #    m.next = "put_z"
-        #with m.Else():
-            m.d.sync += self.out.eq(self.process(None))
-            m.next = "align"
+        return [smod, dmod, amod]