Working on implementing FMA; the f16 RTZ formal proof seems likely to work
[ieee754fpu.git] / src / ieee754 / fpfma / special_cases.py
index 95d3026692465afeba1bf81955b87c3123b9e779..826c32a8e80f2a121a585d7c83bb65e436e97d50 100644 (file)
@@ -3,14 +3,16 @@
 computes `z = (a * c) + b` but only rounds once at the end
 """
 
-from nmutil.pipemodbase import PipeModBase
+from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
 from ieee754.fpcommon.basedata import FPBaseData
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from ieee754.fpcommon.getop import FPPipeContext
 from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge
 from ieee754.fpfma.util import expanded_exponent_shape, \
-    expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape
+    expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape, \
+    EXPANDED_MANTISSA_EXTRA_MSBS, EXPANDED_MANTISSA_EXTRA_LSBS, \
+    product_mantissa_shape
 
 
 class FPFMAInputData(FPBaseData):
@@ -52,13 +54,13 @@ class FPFMASpecialCasesDeNormOutData:
         self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf))
         """mantissa of a input -- un-normalized and with implicit bit added"""
 
-        self.b_mantissa = Signal(multiplicand_mantissa_shape(fpf))
+        self.b_mantissa = Signal(expanded_mantissa_shape(fpf))
         """mantissa of b input
 
         shifted to appropriate location for add and with implicit bit added
         """
 
-        self.c_mantissa = Signal(expanded_mantissa_shape(fpf))
+        self.c_mantissa = Signal(multiplicand_mantissa_shape(fpf))
         """mantissa of c input -- un-normalized and with implicit bit added"""
 
         self.do_sub = Signal()
@@ -123,15 +125,30 @@ class FPFMASpecialCasesDeNorm(PipeModBase):
         out = self.o
 
         a_exponent = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a))
+        m.d.comb += a_exponent.eq(fpf.get_exponent_value(inp.a))
         b_exponent_in = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b))
+        m.d.comb += b_exponent_in.eq(fpf.get_exponent_value(inp.b))
         c_exponent = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c))
+        m.d.comb += c_exponent.eq(fpf.get_exponent_value(inp.c))
+        b_exponent = Signal(expanded_exponent_shape(fpf))
+        m.d.comb += b_exponent.eq(b_exponent_in + EXPANDED_MANTISSA_EXTRA_MSBS)
         prod_exponent = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += prod_exponent.eq(a_exponent + c_exponent)
+
+        # number of bits that the product of two normalized signals needs to
+        # be shifted left to be normalized, e.g. the product of 2 8-bit
+        # numbers `0x80 * 0x80 == 0x4000` and `0x4000` needs to be shifted
+        # left by `PROD_STAY_NORM_SHIFT` bits to be normalized again:
+        # `0x4000 << 1 == 0x8000`
+        PROD_STAY_NORM_SHIFT = 1
+
+        extra_prod_exponent = (expanded_mantissa_shape(fpf).width
+                               - product_mantissa_shape(fpf).width
+                               + PROD_STAY_NORM_SHIFT
+                               - EXPANDED_MANTISSA_EXTRA_LSBS)
+        m.d.comb += prod_exponent.eq(a_exponent + c_exponent
+                                     + extra_prod_exponent)
         prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in)
+        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent)
         b_mantissa_in = Signal(fpf.fraction_width + 1)
         m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b))
         p_sign = Signal()
@@ -150,30 +167,37 @@ class FPFMASpecialCasesDeNorm(PipeModBase):
             ]
         with m.Else():
             m.d.comb += [
-                exponent.eq(b_exponent_in),
+                exponent.eq(b_exponent),
                 b_shift.eq(0),
             ]
 
-        m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width)
+        m.submodules.rshiftm = rshiftm = MultiShiftRMerge(
+            out.b_mantissa.width - EXPANDED_MANTISSA_EXTRA_MSBS,
+            s_max=expanded_exponent_shape(fpf).width - 1)
         m.d.comb += [
-            rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width
-                                             - b_mantissa_in.width)),
+            rshiftm.inp.eq(0),
+            rshiftm.inp[-b_mantissa_in.width:].eq(b_mantissa_in),
             rshiftm.diff.eq(b_shift),
         ]
 
+        keep = {"keep": True}
+
         # handle special cases
         with m.If(fpf.is_nan(inp.a)):
             m.d.comb += [
+                Signal(name="case_nan_a", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)),
                 out.do_bypass.eq(True),
             ]
         with m.Elif(fpf.is_nan(inp.b)):
             m.d.comb += [
+                Signal(name="case_nan_b", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)),
                 out.do_bypass.eq(True),
             ]
         with m.Elif(fpf.is_nan(inp.c)):
             m.d.comb += [
+                Signal(name="case_nan_c", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)),
                 out.do_bypass.eq(True),
             ]
@@ -181,37 +205,50 @@ class FPFMASpecialCasesDeNorm(PipeModBase):
                     | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))):
             # infinity * 0
             m.d.comb += [
+                Signal(name="case_inf_times_zero", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.quiet_nan()),
                 out.do_bypass.eq(True),
             ]
         with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c))
-                    & fpf.is_inf(inp.b) & p_sign != b_sign):
+                    & fpf.is_inf(inp.b) & (p_sign != b_sign)):
             # inf - inf
             m.d.comb += [
+                Signal(name="case_inf_minus_inf", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.quiet_nan()),
                 out.do_bypass.eq(True),
             ]
         with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)):
             # inf + x
             m.d.comb += [
+                Signal(name="case_inf_plus_x", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.inf(p_sign)),
                 out.do_bypass.eq(True),
             ]
         with m.Elif(fpf.is_inf(inp.b)):
             # x + inf
             m.d.comb += [
+                Signal(name="case_x_plus_inf", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.inf(b_sign)),
                 out.do_bypass.eq(True),
             ]
         with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
-                    & fpf.is_zero(inp.b) & p_sign == b_sign):
+                    & fpf.is_zero(inp.b) & (p_sign == b_sign)):
             # zero + zero
             m.d.comb += [
+                Signal(name="case_zero_plus_zero", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.zero(p_sign)),
                 out.do_bypass.eq(True),
             ]
-            # zero - zero handled by FPFMAMainStage
+        with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
+                    & ~fpf.is_zero(inp.b)):
+            # zero + x
+            m.d.comb += [
+                Signal(name="case_zero_plus_x", attrs=keep).eq(True),
+                out.bypassed_z.eq(inp.b),
+                out.do_bypass.eq(True),
+            ]
         with m.Else():
+            # zero - zero handled by FPFMAMainStage
             m.d.comb += [
                 out.bypassed_z.eq(0),
                 out.do_bypass.eq(False),
@@ -229,3 +266,13 @@ class FPFMASpecialCasesDeNorm(PipeModBase):
         ]
 
         return m
+
+
+class FPFMASpecialCasesDeNormStage(PipeModBaseChain):
+    def __init__(self, pspec):
+        super().__init__(pspec)
+
+    def get_chain(self):
+        """ gets chain of modules
+        """
+        return [FPFMASpecialCasesDeNorm(self.pspec)]