fix tests/mark as expected failure
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
index 0f76d7125fbe80a7b133ba290dc40ba35e17298d..9f925b871493a6885d92366d648ea6c57bb39c20 100644 (file)
@@ -18,8 +18,39 @@ Formulas solved are:
 The remainder is the left-hand-side of the comparison minus the
 right-hand-side of the comparison in the above formulas.
 """
-from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat)
+from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Repl)
+from nmigen.lib.coding import PriorityEncoder
+from nmutil.util import treereduce
 import enum
+import operator
+
+
+class DivPipeCoreOperation(enum.Enum):
+    """ Operation for ``DivPipeCore``.
+
+    :attribute UDivRem: unsigned divide/remainder.
+    :attribute SqrtRem: square-root/remainder.
+    :attribute RSqrtRem: reciprocal-square-root/remainder.
+    """
+
+    SqrtRem = 0
+    UDivRem = 1
+    RSqrtRem = 2
+
+    def __int__(self):
+        """ Convert to int. """
+        return self.value
+
+    @classmethod
+    def create_signal(cls, *, src_loc_at=0, **kwargs):
+        """ Create a signal that can contain a ``DivPipeCoreOperation``. """
+        return Signal(range(min(map(int, cls)), max(map(int, cls)) + 2),
+                      src_loc_at=(src_loc_at + 1),
+                      decoder=lambda v: str(cls(v)),
+                      **kwargs)
+
+
+DP = DivPipeCoreOperation
 
 
 class DivPipeCoreConfig:
@@ -32,45 +63,30 @@ class DivPipeCoreConfig:
         computed per pipeline stage.
     """
 
-    def __init__(self, bit_width, fract_width, log2_radix):
+    def __init__(self, bit_width, fract_width, log2_radix, supported=None):
         """ Create a ``DivPipeCoreConfig`` instance. """
         self.bit_width = bit_width
         self.fract_width = fract_width
         self.log2_radix = log2_radix
+        if supported is None:
+            supported = frozenset(DP)
+        else:
+            supported = frozenset(supported)
+        self.supported = supported
+        print(f"{self}: n_stages={self.n_stages}")
 
     def __repr__(self):
         """ Get repr. """
         return f"DivPipeCoreConfig({self.bit_width}, " \
-            + f"{self.fract_width}, {self.log2_radix})"
+            + f"{self.fract_width}, {self.log2_radix}, "\
+            + f"supported={self.supported})"
 
     @property
-    def num_calculate_stages(self):
+    def n_stages(self):
         """ Get the number of ``DivPipeCoreCalculateStage`` needed. """
         return (self.bit_width + self.log2_radix - 1) // self.log2_radix
 
 
-class DivPipeCoreOperation(enum.IntEnum):
-    """ Operation for ``DivPipeCore``.
-
-    :attribute UDivRem: unsigned divide/remainder.
-    :attribute SqrtRem: square-root/remainder.
-    :attribute RSqrtRem: reciprocal-square-root/remainder.
-    """
-
-    UDivRem = 0
-    SqrtRem = 1
-    RSqrtRem = 2
-
-    @classmethod
-    def create_signal(cls, *, src_loc_at=0, **kwargs):
-        """ Create a signal that can contain a ``DivPipeCoreOperation``. """
-        return Signal(min=int(min(cls)),
-                      max=int(max(cls)),
-                      src_loc_at=(src_loc_at + 1),
-                      decoder=cls,
-                      **kwargs)
-
-
 class DivPipeCoreInputData:
     """ input data type for ``DivPipeCore``.
 
@@ -88,26 +104,23 @@ class DivPipeCoreInputData:
     def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreInputData`` instance. """
         self.core_config = core_config
-        self.dividend = Signal(core_config.bit_width + core_config.fract_width,
-                               reset_less=reset_less)
-        self.divisor_radicand = Signal(core_config.bit_width,
-                                       reset_less=reset_less)
-
-        # FIXME: this goes into (is replaced by) self.ctx.op
-        self.operation = \
-            DivPipeCoreOperation.create_signal(reset_less=reset_less)
+        bw = core_config.bit_width
+        fw = core_config.fract_width
+        self.dividend = Signal(bw + fw, reset_less=reset_less)
+        self.divisor_radicand = Signal(bw, reset_less=reset_less)
+        self.operation = DP.create_signal(reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
         yield self.dividend
         yield self.divisor_radicand
-        yield self.operation  # FIXME: delete.  already covered by self.ctx
+        yield self.operation
 
     def eq(self, rhs):
         """ Assign member signals. """
         return [self.dividend.eq(rhs.dividend),
                 self.divisor_radicand.eq(rhs.divisor_radicand),
-                self.operation.eq(rhs.operation),  # FIXME: delete.
+                self.operation.eq(rhs.operation),
                 ]
 
 
@@ -139,24 +152,23 @@ class DivPipeCoreInterstageData:
     def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreInterstageData`` instance. """
         self.core_config = core_config
-        self.divisor_radicand = Signal(core_config.bit_width,
-                                       reset_less=reset_less)
-        # FIXME: delete self.operation.  already covered by self.ctx.op
-        self.operation = \
-            DivPipeCoreOperation.create_signal(reset_less=reset_less)
-        self.quotient_root = Signal(core_config.bit_width,
-                                    reset_less=reset_less)
-        self.root_times_radicand = Signal(core_config.bit_width * 2,
-                                          reset_less=reset_less)
-        self.compare_lhs = Signal(core_config.bit_width * 3,
-                                  reset_less=reset_less)
-        self.compare_rhs = Signal(core_config.bit_width * 3,
-                                  reset_less=reset_less)
+        bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
+        self.divisor_radicand = Signal(bw, reset_less=reset_less)
+        self.operation = DP.create_signal(reset_less=reset_less)
+        self.quotient_root = Signal(bw, reset_less=reset_less)
+        self.root_times_radicand = Signal(bw * 2, reset_less=reset_less)
+        self.compare_lhs = Signal(self.compare_len, reset_less=reset_less)
+        self.compare_rhs = Signal(self.compare_len, reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
         yield self.divisor_radicand
-        yield self.operation  # FIXME: delete.  already in self.ctx.op
+        yield self.operation
         yield self.quotient_root
         yield self.root_times_radicand
         yield self.compare_lhs
@@ -165,7 +177,7 @@ class DivPipeCoreInterstageData:
     def eq(self, rhs):
         """ Assign member signals. """
         return [self.divisor_radicand.eq(rhs.divisor_radicand),
-                self.operation.eq(rhs.operation),  # FIXME: delete.
+                self.operation.eq(rhs.operation),
                 self.quotient_root.eq(rhs.quotient_root),
                 self.root_times_radicand.eq(rhs.root_times_radicand),
                 self.compare_lhs.eq(rhs.compare_lhs),
@@ -188,10 +200,14 @@ class DivPipeCoreOutputData:
     def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreOutputData`` instance. """
         self.core_config = core_config
-        self.quotient_root = Signal(core_config.bit_width,
-                                    reset_less=reset_less)
-        self.remainder = Signal(core_config.bit_width * 3,
-                                reset_less=reset_less)
+        bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
+        self.quotient_root = Signal(bw, reset_less=reset_less)
+        self.remainder = Signal(self.compare_len, reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
@@ -213,6 +229,12 @@ class DivPipeCoreSetupStage(Elaboratable):
         self.core_config = core_config
         self.i = self.ispec()
         self.o = self.ospec()
+        bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
 
     def ispec(self):
         """ Get the input spec for this pipeline stage."""
@@ -234,23 +256,115 @@ class DivPipeCoreSetupStage(Elaboratable):
     def elaborate(self, platform):
         """ Elaborate into ``Module``. """
         m = Module()
+        comb = m.d.comb
+
+        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
+        comb += self.o.quotient_root.eq(0)
+        comb += self.o.root_times_radicand.eq(0)
+
+        lhs = Signal(self.compare_len, reset_less=True)
+        fw = self.core_config.fract_width
+
+        with m.Switch(self.i.operation):
+            with m.Case(int(DP.UDivRem)):
+                comb += lhs.eq(self.i.dividend << fw)
+            with m.Case(int(DP.SqrtRem)):
+                comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
+            with m.Case(int(DP.RSqrtRem)):
+                comb += lhs.eq(1 << (fw * 3))
 
-        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
-        m.d.comb += self.o.quotient_root.eq(0)
-        m.d.comb += self.o.root_times_radicand.eq(0)
+        comb += self.o.compare_lhs.eq(lhs)
+        comb += self.o.compare_rhs.eq(0)
+        comb += self.o.operation.eq(self.i.operation)
+
+        return m
 
-        with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
-            m.d.comb += self.o.compare_lhs.eq(self.i.dividend
-                                              << self.core_config.fract_width)
-        with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
-            m.d.comb += self.o.compare_lhs.eq(
-                self.i.divisor_radicand << (self.core_config.fract_width * 2))
-        with m.Else():  # DivPipeCoreOperation.RSqrtRem
-            m.d.comb += self.o.compare_lhs.eq(
-                1 << (self.core_config.fract_width * 3))
 
-        m.d.comb += self.o.compare_rhs.eq(0)
-        m.d.comb += self.o.operation.eq(self.i.operation)
+class Trial(Elaboratable):
+    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
+        self.core_config = core_config
+        self.trial_bits = trial_bits
+        self.current_shift = current_shift
+        self.log2_radix = log2_radix
+        bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
+        self.divisor_radicand = Signal(bw, reset_less=True)
+        self.quotient_root = Signal(bw, reset_less=True)
+        self.root_times_radicand = Signal(bw * 2, reset_less=True)
+        self.compare_rhs = Signal(self.compare_len, reset_less=True)
+        self.trial_compare_rhs = Signal(self.compare_len, reset_less=True)
+        self.operation = DP.create_signal(reset_less=True)
+
+    def elaborate(self, platform):
+
+        m = Module()
+        comb = m.d.comb
+
+        cc = self.core_config
+        dr = self.divisor_radicand
+
+        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
+        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
+                                    self.log2_radix * 2)
+
+        tblen = self.core_config.bit_width+self.log2_radix
+
+        # UDivRem
+        if DP.UDivRem in cc.supported:
+            with m.If(self.operation == int(DP.UDivRem)):
+                dr_times_trial_bits = Signal(tblen, reset_less=True)
+                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
+                div_rhs = self.compare_rhs
+
+                div_term1 = dr_times_trial_bits
+                div_term1_shift = self.core_config.fract_width
+                div_term1_shift += self.current_shift
+                div_rhs += div_term1 << div_term1_shift
+
+                comb += self.trial_compare_rhs.eq(div_rhs)
+
+        # SqrtRem
+        if DP.SqrtRem in cc.supported:
+            with m.If(self.operation == int(DP.SqrtRem)):
+                qr = self.quotient_root
+                qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
+                comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
+                sqrt_rhs = self.compare_rhs
+
+                sqrt_term1 = qr_times_trial_bits
+                sqrt_term1_shift = self.core_config.fract_width
+                sqrt_term1_shift += self.current_shift + 1
+                sqrt_rhs += sqrt_term1 << sqrt_term1_shift
+                sqrt_term2 = trial_bits_sqrd_sig
+                sqrt_term2_shift = self.core_config.fract_width
+                sqrt_term2_shift += self.current_shift * 2
+                sqrt_rhs += sqrt_term2 << sqrt_term2_shift
+
+                comb += self.trial_compare_rhs.eq(sqrt_rhs)
+
+        # RSqrtRem
+        if DP.RSqrtRem in cc.supported:
+            with m.If(self.operation == int(DP.RSqrtRem)):
+                rr = self.root_times_radicand
+                tblen2 = self.core_config.bit_width+self.log2_radix*2
+                dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
+                comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
+                rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
+                comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
+                rsqrt_rhs = self.compare_rhs
+
+                rsqrt_term1 = rr_times_trial_bits
+                rsqrt_term1_shift = self.current_shift + 1
+                rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
+                rsqrt_term2 = dr_times_trial_bits_sqrd
+                rsqrt_term2_shift = self.current_shift * 2
+                rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
+
+                comb += self.trial_compare_rhs.eq(rsqrt_rhs)
 
         return m
 
@@ -260,8 +374,14 @@ class DivPipeCoreCalculateStage(Elaboratable):
 
     def __init__(self, core_config, stage_index):
         """ Create a ``DivPipeCoreSetupStage`` instance. """
+        assert stage_index in range(core_config.n_stages)
         self.core_config = core_config
-        assert stage_index in range(core_config.num_calculate_stages)
+        bw = core_config.bit_width
+        # TODO(programmerjake): re-enable once bit_width reduction is fixed
+        if False and core_config.supported == {DP.UDivRem}:
+            self.compare_len = bw * 2
+        else:
+            self.compare_len = bw * 3
         self.stage_index = stage_index
         self.i = self.ispec()
         self.o = self.ospec()
@@ -288,61 +408,59 @@ class DivPipeCoreCalculateStage(Elaboratable):
     def elaborate(self, platform):
         """ Elaborate into ``Module``. """
         m = Module()
-        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
-        m.d.comb += self.o.operation.eq(self.i.operation)
-        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)
+        comb = m.d.comb
+        cc = self.core_config
+
+        # copy invariant inputs to outputs (for next stage)
+        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
+        comb += self.o.operation.eq(self.i.operation)
+        comb += self.o.compare_lhs.eq(self.i.compare_lhs)
+
+        # constants
         log2_radix = self.core_config.log2_radix
         current_shift = self.core_config.bit_width
         current_shift -= self.stage_index * log2_radix
         log2_radix = min(log2_radix, current_shift)
         assert log2_radix > 0
         current_shift -= log2_radix
+        print(f"DivPipeCoreCalc: stage {self.stage_index}"
+              + f" of {self.core_config.n_stages} handling "
+              + f"bits [{current_shift}, {current_shift+log2_radix})"
+              + f" of {self.core_config.bit_width}")
         radix = 1 << log2_radix
+
+        # trials within this radix range.  carried out by Trial module,
+        # results stored in pass_flags.  pass_flags are unary priority.
         trial_compare_rhs_values = []
-        pass_flags = []
+        pfl = []
         for trial_bits in range(radix):
-            tb = trial_bits << current_shift
-            log2_tb = log2_radix + current_shift
-            shifted_trial_bits = Const(tb, log2_tb)
-            shifted_trial_bits2 = Const(tb*2, log2_tb+1)
-            shifted_trial_bits_sqrd = Const(tb * tb, log2_tb * 2)
-
-            # UDivRem
-            div_rhs = self.i.compare_rhs
-            if tb != 0: # no point adding stuff that's multiplied by zero
-                div_factor1 = self.i.divisor_radicand * shifted_trial_bits2
-                div_rhs += div_factor1 << self.core_config.fract_width
-
-            # SqrtRem
-            sqrt_rhs = self.i.compare_rhs
-            if tb != 0: # no point adding stuff that's multiplied by zero
-                sqrt_factor1 = self.i.quotient_root * shifted_trial_bits2
-                sqrt_rhs += sqrt_factor1 << self.core_config.fract_width
-                sqrt_factor2 = shifted_trial_bits_sqrd
-                sqrt_rhs += sqrt_factor2 << self.core_config.fract_width
-
-            # RSqrtRem
-            rsqrt_rhs = self.i.compare_rhs
-            if tb != 0: # no point adding stuff that's multiplied by zero
-                rsqrt_rhs += self.i.root_times_radicand * shifted_trial_bits2
-                rsqrt_rhs += self.i.divisor_radicand * shifted_trial_bits_sqrd
-
-            trial_compare_rhs = Signal.like(
-                self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}")
-
-            with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
-                m.d.comb += trial_compare_rhs.eq(div_rhs)
-            with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
-                m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
-            with m.Else():  # DivPipeCoreOperation.RSqrtRem
-                m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
-            trial_compare_rhs_values.append(trial_compare_rhs)
-
-            pass_flag = Signal(name=f"pass_flag_{trial_bits}")
-            m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
-            pass_flags.append(pass_flag)
-
-        # convert pass_flags to next_bits.
+            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
+            setattr(m.submodules, "trial%d" % trial_bits, t)
+
+            comb += t.divisor_radicand.eq(self.i.divisor_radicand)
+            comb += t.quotient_root.eq(self.i.quotient_root)
+            comb += t.root_times_radicand.eq(self.i.root_times_radicand)
+            comb += t.compare_rhs.eq(self.i.compare_rhs)
+            comb += t.operation.eq(self.i.operation)
+
+            # get the trial output (needed even in pass_flags[0] case)
+            trial_compare_rhs_values.append(t.trial_compare_rhs)
+
+            # make the trial comparison against the [invariant] lhs.
+            # trial_compare_rhs is always decreasing as trial_bits increases
+            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
+            if trial_bits == 0:
+                # do not do first comparison: no point.
+                comb += pass_flag.eq(1)
+            else:
+                comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
+            pfl.append(pass_flag)
+
+        # Cat all the pass flags list together (easier to handle, below)
+        pass_flags = Signal(radix, reset_less=True)
+        comb += pass_flags.eq(Cat(*pfl))
+
+        # convert pass_flags (unary priority) to next_bits (binary index)
         #
         # Assumes that for each set bit in pass_flag, all previous bits are
         # also set.
@@ -350,32 +468,33 @@ class DivPipeCoreCalculateStage(Elaboratable):
         # Assumes that pass_flag[0] is always set (since
         # compare_lhs >= compare_rhs is a pipeline invariant).
 
-        next_bits = Signal(log2_radix)
-        for i in range(log2_radix):
-            bit_value = 1
-            for j in range(0, radix, 1 << i):
-                bit_value ^= pass_flags[j]
-            m.d.comb += next_bits.part(i, 1).eq(bit_value)
-
-        next_compare_rhs = Signal(radix, reset_less=True)
-        l = []
+        m.submodules.pe = pe = PriorityEncoder(radix)
+        next_bits = Signal(log2_radix, reset_less=True)
+        comb += pe.i.eq(~pass_flags)
+        with m.If(~pe.n):
+            comb += next_bits.eq(pe.o-1)
+        with m.Else():
+            comb += next_bits.eq(radix-1)
+
+        # get the highest passing rhs trial. use treereduce because
+        # Array on such massively long numbers is insanely gate-hungry
+        crhs = []
+        tcrh = trial_compare_rhs_values
         for i in range(radix):
-            next_flag = pass_flags[i + 1] if (i + 1 < radix) else Const(0)
-            flag = Signal(reset_less=True, name=f"flag{i}")
-            test = Signal(reset_less=True, name=f"test{i}")
-            # XXX TODO: check the width on this
-            m.d.comb += test.eq((pass_flags[i] & ~next_flag))
-            m.d.comb += flag.eq(Mux(test, trial_compare_rhs_values[i], 0))
-            l.append(flag)
-
-        m.d.comb += next_compare_rhs.eq(Cat(*l))
-        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs.bool())
-        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
-                                                  + ((self.i.divisor_radicand
-                                                      * next_bits)
-                                                     << current_shift))
-        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
-                                            | (next_bits << current_shift))
+            nbe = Signal(reset_less=True)
+            comb += nbe.eq(next_bits == i)
+            crhs.append(Repl(nbe, self.compare_len) & tcrh[i])
+        comb += self.o.compare_rhs.eq(treereduce(crhs, operator.or_,
+                                      lambda x:x))
+
+        # create outputs for next phase
+        qr = self.i.quotient_root | (next_bits << current_shift)
+        comb += self.o.quotient_root.eq(qr)
+        if DP.RSqrtRem in cc.supported:
+            rr = self.i.root_times_radicand + ((self.i.divisor_radicand *
+                                               next_bits) << current_shift)
+            comb += self.o.root_times_radicand.eq(rr)
+
         return m
 
 
@@ -408,9 +527,9 @@ class DivPipeCoreFinalStage(Elaboratable):
     def elaborate(self, platform):
         """ Elaborate into ``Module``. """
         m = Module()
+        comb = m.d.comb
 
-        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root)
-        m.d.comb += self.o.remainder.eq(self.i.compare_lhs
-                                        - self.i.compare_rhs)
+        comb += self.o.quotient_root.eq(self.i.quotient_root)
+        comb += self.o.remainder.eq(self.i.compare_lhs - self.i.compare_rhs)
 
         return m