fix tests/mark as expected failure
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
index 1cf76b1d18628f3029feed8c61e7d885dc4aefd1..9f925b871493a6885d92366d648ea6c57bb39c20 100644 (file)
@@ -18,9 +18,39 @@ Formulas solved are:
 The remainder is the left-hand-side of the comparison minus the
 right-hand-side of the comparison in the above formulas.
 """
-from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Array)
+from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Repl)
 from nmigen.lib.coding import PriorityEncoder
+from nmutil.util import treereduce
 import enum
+import operator
+
+
+class DivPipeCoreOperation(enum.Enum):
+    """ Operation for ``DivPipeCore``.
+
+    :attribute UDivRem: unsigned divide/remainder.
+    :attribute SqrtRem: square-root/remainder.
+    :attribute RSqrtRem: reciprocal-square-root/remainder.
+    """
+
+    SqrtRem = 0
+    UDivRem = 1
+    RSqrtRem = 2
+
+    def __int__(self):
+        """ Convert to int. """
+        return self.value
+
+    @classmethod
+    def create_signal(cls, *, src_loc_at=0, **kwargs):
+        """ Create a signal that can contain a ``DivPipeCoreOperation``. """
+        return Signal(range(min(map(int, cls)), max(map(int, cls)) + 2),
+                      src_loc_at=(src_loc_at + 1),
+                      decoder=lambda v: str(cls(v)),
+                      **kwargs)
+
+
+DP = DivPipeCoreOperation
 
 
 class DivPipeCoreConfig:
@@ -33,17 +63,23 @@ class DivPipeCoreConfig:
         computed per pipeline stage.
     """
 
-    def __init__(self, bit_width, fract_width, log2_radix):
+    def __init__(self, bit_width, fract_width, log2_radix, supported=None):
         """ Create a ``DivPipeCoreConfig`` instance. """
         self.bit_width = bit_width
         self.fract_width = fract_width
         self.log2_radix = log2_radix
+        if supported is None:
+            supported = frozenset(DP)
+        else:
+            supported = frozenset(supported)
+        self.supported = supported
         print(f"{self}: n_stages={self.n_stages}")
 
     def __repr__(self):
         """ Get repr. """
         return f"DivPipeCoreConfig({self.bit_width}, " \
-            + f"{self.fract_width}, {self.log2_radix})"
+            + f"{self.fract_width}, {self.log2_radix}, "\
+            + f"supported={self.supported})"
 
     @property
     def n_stages(self):
@@ -51,35 +87,6 @@ class DivPipeCoreConfig:
         return (self.bit_width + self.log2_radix - 1) // self.log2_radix
 
 
-class DivPipeCoreOperation(enum.Enum):
-    """ Operation for ``DivPipeCore``.
-
-    :attribute UDivRem: unsigned divide/remainder.
-    :attribute SqrtRem: square-root/remainder.
-    :attribute RSqrtRem: reciprocal-square-root/remainder.
-    """
-
-    UDivRem = 0
-    SqrtRem = 1
-    RSqrtRem = 2
-
-    def __int__(self):
-        """ Convert to int. """
-        return self.value
-
-    @classmethod
-    def create_signal(cls, *, src_loc_at=0, **kwargs):
-        """ Create a signal that can contain a ``DivPipeCoreOperation``. """
-        return Signal(min=min(map(int, cls)),
-                      max=max(map(int, cls)) + 2,
-                      src_loc_at=(src_loc_at + 1),
-                      decoder=lambda v: str(cls(v)),
-                      **kwargs)
-
-
-DP = DivPipeCoreOperation
-
-
 class DivPipeCoreInputData:
     """ input data type for ``DivPipeCore``.
 
@@ -97,10 +104,10 @@ class DivPipeCoreInputData:
     def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreInputData`` instance. """
         self.core_config = core_config
-        self.dividend = Signal(core_config.bit_width + core_config.fract_width,
-                               reset_less=reset_less)
-        self.divisor_radicand = Signal(core_config.bit_width,
-                                       reset_less=reset_less)
+        bw = core_config.bit_width
+        fw = core_config.fract_width
+        self.dividend = Signal(bw + fw, reset_less=reset_less)
+        self.divisor_radicand = Signal(bw, reset_less=reset_less)
         self.operation = DP.create_signal(reset_less=reset_less)
 
     def __iter__(self):
@@ -145,17 +152,18 @@ class DivPipeCoreInterstageData:
     def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreInterstageData`` instance. """
         self.core_config = core_config
-        self.divisor_radicand = Signal(core_config.bit_width,
-                                       reset_less=reset_less)
+        bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
+        self.divisor_radicand = Signal(bw, reset_less=reset_less)
         self.operation = DP.create_signal(reset_less=reset_less)
-        self.quotient_root = Signal(core_config.bit_width,
-                                    reset_less=reset_less)
-        self.root_times_radicand = Signal(core_config.bit_width * 2,
-                                          reset_less=reset_less)
-        self.compare_lhs = Signal(core_config.bit_width * 3,
-                                  reset_less=reset_less)
-        self.compare_rhs = Signal(core_config.bit_width * 3,
-                                  reset_less=reset_less)
+        self.quotient_root = Signal(bw, reset_less=reset_less)
+        self.root_times_radicand = Signal(bw * 2, reset_less=reset_less)
+        self.compare_lhs = Signal(self.compare_len, reset_less=reset_less)
+        self.compare_rhs = Signal(self.compare_len, reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
@@ -192,10 +200,14 @@ class DivPipeCoreOutputData:
     def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreOutputData`` instance. """
         self.core_config = core_config
-        self.quotient_root = Signal(core_config.bit_width,
-                                    reset_less=reset_less)
-        self.remainder = Signal(core_config.bit_width * 3,
-                                reset_less=reset_less)
+        bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
+        self.quotient_root = Signal(bw, reset_less=reset_less)
+        self.remainder = Signal(self.compare_len, reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
@@ -217,6 +229,12 @@ class DivPipeCoreSetupStage(Elaboratable):
         self.core_config = core_config
         self.i = self.ispec()
         self.o = self.ospec()
+        bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
 
     def ispec(self):
         """ Get the input spec for this pipeline stage."""
@@ -244,15 +262,16 @@ class DivPipeCoreSetupStage(Elaboratable):
         comb += self.o.quotient_root.eq(0)
         comb += self.o.root_times_radicand.eq(0)
 
-        lhs = Signal(self.core_config.bit_width * 3, reset_less=True)
+        lhs = Signal(self.compare_len, reset_less=True)
         fw = self.core_config.fract_width
 
-        with m.If(self.i.operation == int(DP.UDivRem)):
-            comb += lhs.eq(self.i.dividend << fw)
-        with m.Elif(self.i.operation == int(DP.SqrtRem)):
-            comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
-        with m.Else():  # DivPipeCoreOperation.RSqrtRem
-            comb += lhs.eq(1 << (fw * 3))
+        with m.Switch(self.i.operation):
+            with m.Case(int(DP.UDivRem)):
+                comb += lhs.eq(self.i.dividend << fw)
+            with m.Case(int(DP.SqrtRem)):
+                comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
+            with m.Case(int(DP.RSqrtRem)):
+                comb += lhs.eq(1 << (fw * 3))
 
         comb += self.o.compare_lhs.eq(lhs)
         comb += self.o.compare_rhs.eq(0)
@@ -268,11 +287,16 @@ class Trial(Elaboratable):
         self.current_shift = current_shift
         self.log2_radix = log2_radix
         bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
         self.divisor_radicand = Signal(bw, reset_less=True)
         self.quotient_root = Signal(bw, reset_less=True)
         self.root_times_radicand = Signal(bw * 2, reset_less=True)
-        self.compare_rhs = Signal(bw * 3, reset_less=True)
-        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
+        self.compare_rhs = Signal(self.compare_len, reset_less=True)
+        self.trial_compare_rhs = Signal(self.compare_len, reset_less=True)
         self.operation = DP.create_signal(reset_less=True)
 
     def elaborate(self, platform):
@@ -280,63 +304,67 @@ class Trial(Elaboratable):
         m = Module()
         comb = m.d.comb
 
+        cc = self.core_config
         dr = self.divisor_radicand
-        qr = self.quotient_root
-        rr = self.root_times_radicand
 
         trial_bits_sig = Const(self.trial_bits, self.log2_radix)
         trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
                                     self.log2_radix * 2)
 
         tblen = self.core_config.bit_width+self.log2_radix
-        tblen2 = self.core_config.bit_width+self.log2_radix*2
-        dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
-        comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
 
         # UDivRem
-        with m.If(self.operation == int(DP.UDivRem)):
-            dr_times_trial_bits = Signal(tblen, reset_less=True)
-            comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
-            div_rhs = self.compare_rhs
+        if DP.UDivRem in cc.supported:
+            with m.If(self.operation == int(DP.UDivRem)):
+                dr_times_trial_bits = Signal(tblen, reset_less=True)
+                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
+                div_rhs = self.compare_rhs
 
-            div_term1 = dr_times_trial_bits
-            div_term1_shift = self.core_config.fract_width
-            div_term1_shift += self.current_shift
-            div_rhs += div_term1 << div_term1_shift
+                div_term1 = dr_times_trial_bits
+                div_term1_shift = self.core_config.fract_width
+                div_term1_shift += self.current_shift
+                div_rhs += div_term1 << div_term1_shift
 
-            comb += self.trial_compare_rhs.eq(div_rhs)
+                comb += self.trial_compare_rhs.eq(div_rhs)
 
         # SqrtRem
-        with m.Elif(self.operation == int(DP.SqrtRem)):
-            qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
-            comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
-            sqrt_rhs = self.compare_rhs
-
-            sqrt_term1 = qr_times_trial_bits
-            sqrt_term1_shift = self.core_config.fract_width
-            sqrt_term1_shift += self.current_shift + 1
-            sqrt_rhs += sqrt_term1 << sqrt_term1_shift
-            sqrt_term2 = trial_bits_sqrd_sig
-            sqrt_term2_shift = self.core_config.fract_width
-            sqrt_term2_shift += self.current_shift * 2
-            sqrt_rhs += sqrt_term2 << sqrt_term2_shift
-
-            comb += self.trial_compare_rhs.eq(sqrt_rhs)
+        if DP.SqrtRem in cc.supported:
+            with m.If(self.operation == int(DP.SqrtRem)):
+                qr = self.quotient_root
+                qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
+                comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
+                sqrt_rhs = self.compare_rhs
+
+                sqrt_term1 = qr_times_trial_bits
+                sqrt_term1_shift = self.core_config.fract_width
+                sqrt_term1_shift += self.current_shift + 1
+                sqrt_rhs += sqrt_term1 << sqrt_term1_shift
+                sqrt_term2 = trial_bits_sqrd_sig
+                sqrt_term2_shift = self.core_config.fract_width
+                sqrt_term2_shift += self.current_shift * 2
+                sqrt_rhs += sqrt_term2 << sqrt_term2_shift
+
+                comb += self.trial_compare_rhs.eq(sqrt_rhs)
 
         # RSqrtRem
-        with m.Else():
-            rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
-            comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
-            rsqrt_rhs = self.compare_rhs
-
-            rsqrt_term1 = rr_times_trial_bits
-            rsqrt_term1_shift = self.current_shift + 1
-            rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
-            rsqrt_term2 = dr_times_trial_bits_sqrd
-            rsqrt_term2_shift = self.current_shift * 2
-            rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
-
-            comb += self.trial_compare_rhs.eq(rsqrt_rhs)
+        if DP.RSqrtRem in cc.supported:
+            with m.If(self.operation == int(DP.RSqrtRem)):
+                rr = self.root_times_radicand
+                tblen2 = self.core_config.bit_width+self.log2_radix*2
+                dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
+                comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
+                rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
+                comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
+                rsqrt_rhs = self.compare_rhs
+
+                rsqrt_term1 = rr_times_trial_bits
+                rsqrt_term1_shift = self.current_shift + 1
+                rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
+                rsqrt_term2 = dr_times_trial_bits_sqrd
+                rsqrt_term2_shift = self.current_shift * 2
+                rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
+
+                comb += self.trial_compare_rhs.eq(rsqrt_rhs)
 
         return m
 
@@ -346,8 +374,14 @@ class DivPipeCoreCalculateStage(Elaboratable):
 
     def __init__(self, core_config, stage_index):
         """ Create a ``DivPipeCoreSetupStage`` instance. """
-        self.core_config = core_config
         assert stage_index in range(core_config.n_stages)
+        self.core_config = core_config
+        bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
         self.stage_index = stage_index
         self.i = self.ispec()
         self.o = self.ospec()
@@ -375,6 +409,7 @@ class DivPipeCoreCalculateStage(Elaboratable):
         """ Elaborate into ``Module``. """
         m = Module()
         comb = m.d.comb
+        cc = self.core_config
 
         # copy invariant inputs to outputs (for next stage)
         comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
@@ -408,13 +443,17 @@ class DivPipeCoreCalculateStage(Elaboratable):
             comb += t.compare_rhs.eq(self.i.compare_rhs)
             comb += t.operation.eq(self.i.operation)
 
-            # get the trial output
+            # get the trial output (needed even in pass_flags[0] case)
             trial_compare_rhs_values.append(t.trial_compare_rhs)
 
             # make the trial comparison against the [invariant] lhs.
             # trial_compare_rhs is always decreasing as trial_bits increases
             pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
-            comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
+            if trial_bits == 0:
+                # do not do first comparison: no point.
+                comb += pass_flag.eq(1)
+            else:
+                comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
             pfl.append(pass_flag)
 
         # Cat all the pass flags list together (easier to handle, below)
@@ -437,16 +476,24 @@ class DivPipeCoreCalculateStage(Elaboratable):
         with m.Else():
             comb += next_bits.eq(radix-1)
 
-        # get the highest passing rhs trial (indexed by next_bits)
-        ta = Array(trial_compare_rhs_values)
-        comb += self.o.compare_rhs.eq(ta[next_bits])
+        # get the highest passing rhs trial. use treereduce because
+        # Array on such massively long numbers is insanely gate-hungry
+        crhs = []
+        tcrh = trial_compare_rhs_values
+        for i in range(radix):
+            nbe = Signal(reset_less=True)
+            comb += nbe.eq(next_bits == i)
+            crhs.append(Repl(nbe, self.compare_len) & tcrh[i])
+        comb += self.o.compare_rhs.eq(treereduce(crhs, operator.or_,
+                                      lambda x:x))
 
         # create outputs for next phase
         qr = self.i.quotient_root | (next_bits << current_shift)
-        rr = self.i.root_times_radicand + ((self.i.divisor_radicand * next_bits)
-                                                     << current_shift)
         comb += self.o.quotient_root.eq(qr)
-        comb += self.o.root_times_radicand.eq(rr)
+        if DP.RSqrtRem in cc.supported:
+            rr = self.i.root_times_radicand + ((self.i.divisor_radicand *
+                                               next_bits) << current_shift)
+            comb += self.o.root_times_radicand.eq(rr)
 
         return m