Revert "reduce LHS for RSQRT by a factor of fract_width and"
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
index 1227f11feba1e2bf9a09458b645e64fee2b97878..b62a9574d2684320590859ca126c680e69fcb696 100644 (file)
@@ -18,13 +18,10 @@ Formulas solved are:
 The remainder is the left-hand-side of the comparison minus the
 right-hand-side of the comparison in the above formulas.
 """
-from nmigen import (Elaboratable, Module, Signal)
+from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Array)
+from nmigen.lib.coding import PriorityEncoder
 import enum
 
-# TODO
-#from ieee754.fpcommon.fpbase import FPNumBaseRecord
-#from ieee754.fpcommon.getop import FPPipeContext
-
 
 class DivPipeCoreConfig:
     """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.
@@ -47,8 +44,13 @@ class DivPipeCoreConfig:
         return f"DivPipeCoreConfig({self.bit_width}, " \
             + f"{self.fract_width}, {self.log2_radix})"
 
+    @property
+    def n_stages(self):
+        """ Get the number of ``DivPipeCoreCalculateStage`` needed. """
+        return (self.bit_width + self.log2_radix - 1) // self.log2_radix
+
 
-class DivPipeCoreOperation(enum.IntEnum):
+class DivPipeCoreOperation(enum.Enum):
     """ Operation for ``DivPipeCore``.
 
     :attribute UDivRem: unsigned divide/remainder.
@@ -60,16 +62,23 @@ class DivPipeCoreOperation(enum.IntEnum):
     SqrtRem = 1
     RSqrtRem = 2
 
+    def __int__(self):
+        """ Convert to int. """
+        return self.value
+
     @classmethod
     def create_signal(cls, *, src_loc_at=0, **kwargs):
         """ Create a signal that can contain a ``DivPipeCoreOperation``. """
-        return Signal(min=int(min(cls)),
-                      max=int(max(cls)),
+        return Signal(min=min(map(int, cls)),
+                      max=max(map(int, cls)) + 2,
                       src_loc_at=(src_loc_at + 1),
-                      decoder=cls,
+                      decoder=lambda v: str(cls(v)),
                       **kwargs)
 
 
+DP = DivPipeCoreOperation
+
+
 class DivPipeCoreInputData:
     """ input data type for ``DivPipeCore``.
 
@@ -84,43 +93,27 @@ class DivPipeCoreInputData:
     :attribute operation: the ``DivPipeCoreOperation`` to be computed.
     """
 
-    def __init__(self, core_config):
+    def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreInputData`` instance. """
         self.core_config = core_config
         self.dividend = Signal(core_config.bit_width + core_config.fract_width,
-                               reset_less=True)
-        self.divisor_radicand = Signal(core_config.bit_width, reset_less=True)
-        self.operation = DivPipeCoreOperation.create_signal(reset_less=True)
-
-        return # TODO: needs a width argument and a pspec
-        self.z = FPNumBaseRecord(width, False)
-        self.out_do_z = Signal(reset_less=True)
-        self.oz = Signal(width, reset_less=True)
-
-        self.ctx = FPPipeContext(width, pspec) # context: muxid, operator etc.
-        self.muxid = self.ctx.muxid             # annoying. complicated.
-
+                               reset_less=reset_less)
+        self.divisor_radicand = Signal(core_config.bit_width,
+                                       reset_less=reset_less)
+        self.operation = DP.create_signal(reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
         yield self.dividend
         yield self.divisor_radicand
         yield self.operation
-        return
-        yield self.z
-        yield self.out_do_z
-        yield self.oz
-        yield from self.ctx
 
     def eq(self, rhs):
         """ Assign member signals. """
         return [self.dividend.eq(rhs.dividend),
                 self.divisor_radicand.eq(rhs.divisor_radicand),
-                self.operation.eq(rhs.operation)]
-        # TODO: and these
-        return [self.out_do_z.eq(i.out_do_z), self.oz.eq(i.oz),
-                self.ctx.eq(i.ctx)]
-
+                self.operation.eq(rhs.operation),
+                ]
 
 
 class DivPipeCoreInterstageData:
@@ -148,23 +141,20 @@ class DivPipeCoreInterstageData:
         ``core_config.fract_width * 3`` bits.
     """
 
-    def __init__(self, core_config):
+    def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreInterstageData`` instance. """
         self.core_config = core_config
-        self.divisor_radicand = Signal(core_config.bit_width, reset_less=True)
-        self.operation = DivPipeCoreOperation.create_signal(reset_less=True)
-        self.quotient_root = Signal(core_config.bit_width, reset_less=True)
+        self.divisor_radicand = Signal(core_config.bit_width,
+                                       reset_less=reset_less)
+        self.operation = DP.create_signal(reset_less=reset_less)
+        self.quotient_root = Signal(core_config.bit_width,
+                                    reset_less=reset_less)
         self.root_times_radicand = Signal(core_config.bit_width * 2,
-                                          reset_less=True)
-        self.compare_lhs = Signal(core_config.bit_width * 3, reset_less=True)
-        self.compare_rhs = Signal(core_config.bit_width * 3, reset_less=True)
-        return # TODO: needs a width argument and a pspec
-        self.z = FPNumBaseRecord(width, False)
-        self.out_do_z = Signal(reset_less=True)
-        self.oz = Signal(width, reset_less=True)
-
-        self.ctx = FPPipeContext(width, pspec) # context: muxid, operator etc.
-        self.muxid = self.ctx.muxid             # annoying. complicated.
+                                          reset_less=reset_less)
+        self.compare_lhs = Signal(core_config.bit_width * 3,
+                                  reset_less=reset_less)
+        self.compare_rhs = Signal(core_config.bit_width * 3,
+                                  reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
@@ -174,11 +164,6 @@ class DivPipeCoreInterstageData:
         yield self.root_times_radicand
         yield self.compare_lhs
         yield self.compare_rhs
-        return
-        yield self.z
-        yield self.out_do_z
-        yield self.oz
-        yield from self.ctx
 
     def eq(self, rhs):
         """ Assign member signals. """
@@ -188,15 +173,44 @@ class DivPipeCoreInterstageData:
                 self.root_times_radicand.eq(rhs.root_times_radicand),
                 self.compare_lhs.eq(rhs.compare_lhs),
                 self.compare_rhs.eq(rhs.compare_rhs)]
-        # TODO: and these
-        return [self.out_do_z.eq(i.out_do_z), self.oz.eq(i.oz),
-                self.ctx.eq(i.ctx)]
 
 
-class DivPipeCoreSetupStage(Elaboratable):
-    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.
+class DivPipeCoreOutputData:
+    """ output data type for ``DivPipeCore``.
+
+    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
+        configuration to be used.
+    :attribute quotient_root: the quotient or root part of the result of the
+        operation. Signal with a bit-width of ``core_config.bit_width`` and a
+        fract-width of ``core_config.fract_width`` bits.
+    :attribute remainder: the remainder part of the result of the operation.
+        Signal with a bit-width of ``core_config.bit_width * 3`` and a
+        fract-width of ``core_config.fract_width * 3`` bits.
     """
 
+    def __init__(self, core_config, reset_less=True):
+        """ Create a ``DivPipeCoreOutputData`` instance. """
+        self.core_config = core_config
+        self.quotient_root = Signal(core_config.bit_width,
+                                    reset_less=reset_less)
+        self.remainder = Signal(core_config.bit_width * 3,
+                                reset_less=reset_less)
+
+    def __iter__(self):
+        """ Get member signals. """
+        yield self.quotient_root
+        yield self.remainder
+        return
+
+    def eq(self, rhs):
+        """ Assign member signals. """
+        return [self.quotient_root.eq(rhs.quotient_root),
+                self.remainder.eq(rhs.remainder)]
+
+
+class DivPipeCoreSetupStage(Elaboratable):
+    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
+
     def __init__(self, core_config):
         """ Create a ``DivPipeCoreSetupStage`` instance."""
         self.core_config = core_config
@@ -228,10 +242,10 @@ class DivPipeCoreSetupStage(Elaboratable):
         m.d.comb += self.o.quotient_root.eq(0)
         m.d.comb += self.o.root_times_radicand.eq(0)
 
-        with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
+        with m.If(self.i.operation == int(DP.UDivRem)):
             m.d.comb += self.o.compare_lhs.eq(self.i.dividend
                                               << self.core_config.fract_width)
-        with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
+        with m.Elif(self.i.operation == int(DP.SqrtRem)):
             m.d.comb += self.o.compare_lhs.eq(
                 self.i.divisor_radicand << (self.core_config.fract_width * 2))
         with m.Else():  # DivPipeCoreOperation.RSqrtRem
@@ -243,8 +257,223 @@ class DivPipeCoreSetupStage(Elaboratable):
 
         return m
 
-        # TODO: these as well
-        m.d.comb += self.o.oz.eq(self.i.oz)
-        m.d.comb += self.o.out_do_z.eq(self.i.out_do_z)
-        m.d.comb += self.o.ctx.eq(self.i.ctx)
 
+class Trial(Elaboratable):
+    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
+        self.core_config = core_config
+        self.trial_bits = trial_bits
+        self.current_shift = current_shift
+        self.log2_radix = log2_radix
+        bw = core_config.bit_width
+        self.divisor_radicand = Signal(bw, reset_less=True)
+        self.quotient_root = Signal(bw, reset_less=True)
+        self.root_times_radicand = Signal(bw * 2, reset_less=True)
+        self.compare_rhs = Signal(bw * 3, reset_less=True)
+        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
+        self.operation = DP.create_signal(reset_less=True)
+
+    def elaborate(self, platform):
+
+        m = Module()
+
+        dr = self.divisor_radicand
+        qr = self.quotient_root
+        rr = self.root_times_radicand
+
+        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
+        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
+                                    self.log2_radix * 2)
+
+        tblen = self.core_config.bit_width+self.log2_radix
+        tblen2 = self.core_config.bit_width+self.log2_radix*2
+        dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
+        m.d.comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
+
+        # UDivRem
+        with m.If(self.operation == int(DP.UDivRem)):
+            dr_times_trial_bits = Signal(tblen, reset_less=True)
+            m.d.comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
+            div_rhs = self.compare_rhs
+
+            div_term1 = dr_times_trial_bits
+            div_term1_shift = self.core_config.fract_width
+            div_term1_shift += self.current_shift
+            div_rhs += div_term1 << div_term1_shift
+
+            m.d.comb += self.trial_compare_rhs.eq(div_rhs)
+
+        # SqrtRem
+        with m.Elif(self.operation == int(DP.SqrtRem)):
+            qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
+            m.d.comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
+            sqrt_rhs = self.compare_rhs
+
+            sqrt_term1 = qr_times_trial_bits
+            sqrt_term1_shift = self.core_config.fract_width
+            sqrt_term1_shift += self.current_shift + 1
+            sqrt_rhs += sqrt_term1 << sqrt_term1_shift
+            sqrt_term2 = trial_bits_sqrd_sig
+            sqrt_term2_shift = self.core_config.fract_width
+            sqrt_term2_shift += self.current_shift * 2
+            sqrt_rhs += sqrt_term2 << sqrt_term2_shift
+
+            m.d.comb += self.trial_compare_rhs.eq(sqrt_rhs)
+
+        # RSqrtRem
+        with m.Else():
+            rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
+            m.d.comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
+            rsqrt_rhs = self.compare_rhs
+
+            rsqrt_term1 = rr_times_trial_bits
+            rsqrt_term1_shift = self.current_shift + 1
+            rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
+            rsqrt_term2 = dr_times_trial_bits_sqrd
+            rsqrt_term2_shift = self.current_shift * 2
+            rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
+
+            m.d.comb += self.trial_compare_rhs.eq(rsqrt_rhs)
+
+        return m
+
+
+class DivPipeCoreCalculateStage(Elaboratable):
+    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
+
+    def __init__(self, core_config, stage_index):
+        """ Create a ``DivPipeCoreSetupStage`` instance. """
+        self.core_config = core_config
+        assert stage_index in range(core_config.n_stages)
+        self.stage_index = stage_index
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        """ Get the input spec for this pipeline stage. """
+        return DivPipeCoreInterstageData(self.core_config)
+
+    def ospec(self):
+        """ Get the output spec for this pipeline stage. """
+        return DivPipeCoreInterstageData(self.core_config)
+
+    def setup(self, m, i):
+        """ Pipeline stage setup. """
+        setattr(m.submodules,
+                f"div_pipe_core_calculate_{self.stage_index}",
+                self)
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        """ Pipeline stage process. """
+        return self.o
+
+    def elaborate(self, platform):
+        """ Elaborate into ``Module``. """
+        m = Module()
+
+        # copy invariant inputs to outputs (for next stage)
+        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
+        m.d.comb += self.o.operation.eq(self.i.operation)
+        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)
+
+        # constants
+        log2_radix = self.core_config.log2_radix
+        current_shift = self.core_config.bit_width
+        current_shift -= self.stage_index * log2_radix
+        log2_radix = min(log2_radix, current_shift)
+        assert log2_radix > 0
+        current_shift -= log2_radix
+        radix = 1 << log2_radix
+
+        # trials within this radix range.  carried out by Trial module,
+        # results stored in pass_flags.  pass_flags are unary priority.
+        trial_compare_rhs_values = []
+        pfl = []
+        for trial_bits in range(radix):
+            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
+            setattr(m.submodules, "trial%d" % trial_bits, t)
+
+            m.d.comb += t.divisor_radicand.eq(self.i.divisor_radicand)
+            m.d.comb += t.quotient_root.eq(self.i.quotient_root)
+            m.d.comb += t.root_times_radicand.eq(self.i.root_times_radicand)
+            m.d.comb += t.compare_rhs.eq(self.i.compare_rhs)
+            m.d.comb += t.operation.eq(self.i.operation)
+
+            # get the trial output
+            trial_compare_rhs_values.append(t.trial_compare_rhs)
+
+            # make the trial comparison against the [invariant] lhs.
+            # trial_compare_rhs is always decreasing as trial_bits increases
+            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
+            m.d.comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
+            pfl.append(pass_flag)
+
+        # Cat all the pass flags list together (easier to handle, below)
+        pass_flags = Signal(radix, reset_less=True)
+        m.d.comb += pass_flags.eq(Cat(*pfl))
+
+        # convert pass_flags (unary priority) to next_bits (binary index)
+        #
+        # Assumes that for each set bit in pass_flag, all previous bits are
+        # also set.
+        #
+        # Assumes that pass_flag[0] is always set (since
+        # compare_lhs >= compare_rhs is a pipeline invariant).
+
+        m.submodules.pe = pe = PriorityEncoder(radix)
+        next_bits = Signal(log2_radix, reset_less=True)
+        m.d.comb += pe.i.eq(~pass_flags)
+        with m.If(~pe.n):
+            m.d.comb += next_bits.eq(pe.o-1)
+        with m.Else():
+            m.d.comb += next_bits.eq(radix-1)
+
+        # get the highest passing rhs trial (indexed by next_bits)
+        ta = Array(trial_compare_rhs_values)
+        m.d.comb += self.o.compare_rhs.eq(ta[next_bits])
+
+        # create outputs for next phase
+        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
+                                                  + ((self.i.divisor_radicand
+                                                      * next_bits)
+                                                     << current_shift))
+        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
+                                            | (next_bits << current_shift))
+        return m
+
+
+class DivPipeCoreFinalStage(Elaboratable):
+    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
+
+    def __init__(self, core_config):
+        """ Create a ``DivPipeCoreFinalStage`` instance."""
+        self.core_config = core_config
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        """ Get the input spec for this pipeline stage."""
+        return DivPipeCoreInterstageData(self.core_config)
+
+    def ospec(self):
+        """ Get the output spec for this pipeline stage."""
+        return DivPipeCoreOutputData(self.core_config)
+
+    def setup(self, m, i):
+        """ Pipeline stage setup. """
+        m.submodules.div_pipe_core_final = self
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        """ Pipeline stage process. """
+        return self.o  # return processed data (ignore i)
+
+    def elaborate(self, platform):
+        """ Elaborate into ``Module``. """
+        m = Module()
+
+        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root)
+        m.d.comb += self.o.remainder.eq(self.i.compare_lhs
+                                        - self.i.compare_rhs)
+
+        return m