restore important modifications that seemed to be lost
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
index afc331de6a0fb72de753fda50e98b6776cf77db8..da1be3ac0da24b62bd42d78d82dc6539c5e75de7 100644 (file)
@@ -49,7 +49,7 @@ class DivPipeCoreConfig:
         return (self.bit_width + self.log2_radix - 1) // self.log2_radix
 
 
-class DivPipeCoreOperation(enum.IntEnum):
+class DivPipeCoreOperation(enum.Enum):
     """ Operation for ``DivPipeCore``.
 
     :attribute UDivRem: unsigned divide/remainder.
@@ -61,13 +61,17 @@ class DivPipeCoreOperation(enum.IntEnum):
     SqrtRem = 1
     RSqrtRem = 2
 
+    def __int__(self):
+        """ Convert to int. """
+        return self.value
+
     @classmethod
     def create_signal(cls, *, src_loc_at=0, **kwargs):
         """ Create a signal that can contain a ``DivPipeCoreOperation``. """
-        return Signal(min=int(min(cls)),
-                      max=int(max(cls)),
+        return Signal(min=min(map(int, cls)),
+                      max=max(map(int, cls)) + 2,
                       src_loc_at=(src_loc_at + 1),
-                      decoder=cls,
+                      decoder=lambda v: str(cls(v)),
                       **kwargs)
 
 
@@ -239,10 +243,10 @@ class DivPipeCoreSetupStage(Elaboratable):
         m.d.comb += self.o.quotient_root.eq(0)
         m.d.comb += self.o.root_times_radicand.eq(0)
 
-        with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
+        with m.If(self.i.operation == int(DivPipeCoreOperation.UDivRem)):
             m.d.comb += self.o.compare_lhs.eq(self.i.dividend
                                               << self.core_config.fract_width)
-        with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
+        with m.Elif(self.i.operation == int(DivPipeCoreOperation.SqrtRem)):
             m.d.comb += self.o.compare_lhs.eq(
                 self.i.divisor_radicand << (self.core_config.fract_width * 2))
         with m.Else():  # DivPipeCoreOperation.RSqrtRem
@@ -301,41 +305,59 @@ class DivPipeCoreCalculateStage(Elaboratable):
         trial_compare_rhs_values = []
         pass_flags = []
         for trial_bits in range(radix):
-            tb = trial_bits << current_shift
-            log2_tb = log2_radix + current_shift
-            shifted_trial_bits = Const(tb, log2_tb)
-            shifted_trial_bits2 = Const(tb*2, log2_tb+1)
-            shifted_trial_bits_sqrd = Const(tb * tb, log2_tb * 2)
+            trial_bits_sig = Const(trial_bits, log2_radix)
+            trial_bits_sqrd_sig = Const(trial_bits * trial_bits,
+                                        log2_radix * 2)
+
+            dr_times_trial_bits = self.i.divisor_radicand * trial_bits_sig
+            dr_times_trial_bits_sqrd = self.i.divisor_radicand \
+                * trial_bits_sqrd_sig
+            qr_times_trial_bits = self.i.quotient_root * trial_bits_sig
+            rr_times_trial_bits = self.i.root_times_radicand * trial_bits_sig
 
             # UDivRem
             div_rhs = self.i.compare_rhs
-            div_factor1 = self.i.divisor_radicand * shifted_trial_bits2
-            div_rhs += div_factor1 << self.core_config.fract_width
+            if trial_bits != 0:  # no point adding stuff that's multiplied by zero
+                div_term1 = dr_times_trial_bits
+                div_term1_shift = self.core_config.fract_width
+                div_term1_shift += current_shift
+                div_rhs += div_term1 << div_term1_shift
 
             # SqrtRem
             sqrt_rhs = self.i.compare_rhs
-            sqrt_factor1 = self.i.quotient_root * shifted_trial_bits2
-            sqrt_rhs += sqrt_factor1 << self.core_config.fract_width
-            sqrt_factor2 = shifted_trial_bits_sqrd
-            sqrt_rhs += sqrt_factor2 << self.core_config.fract_width
+            if trial_bits != 0:  # no point adding stuff that's multiplied by zero
+                sqrt_term1 = qr_times_trial_bits
+                sqrt_term1_shift = self.core_config.fract_width
+                sqrt_term1_shift += current_shift + 1
+                sqrt_rhs += sqrt_term1 << sqrt_term1_shift
+                sqrt_term2 = trial_bits_sqrd_sig
+                sqrt_term2_shift = self.core_config.fract_width
+                sqrt_term2_shift += current_shift * 2
+                sqrt_rhs += sqrt_term2 << sqrt_term2_shift
 
             # RSqrtRem
             rsqrt_rhs = self.i.compare_rhs
-            rsqrt_rhs += self.i.root_times_radicand * shifted_trial_bits2
-            rsqrt_rhs += self.i.divisor_radicand * shifted_trial_bits_sqrd
+            if trial_bits != 0:  # no point adding stuff that's multiplied by zero
+                rsqrt_term1 = rr_times_trial_bits
+                rsqrt_term1_shift = current_shift + 1
+                rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
+                rsqrt_term2 = dr_times_trial_bits_sqrd
+                rsqrt_term2_shift = current_shift * 2
+                rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
 
             trial_compare_rhs = Signal.like(
-                self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}")
+                self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
+                reset_less=True)
 
-            with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
+            with m.If(self.i.operation == int(DivPipeCoreOperation.UDivRem)):
                 m.d.comb += trial_compare_rhs.eq(div_rhs)
-            with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
+            with m.Elif(self.i.operation == int(DivPipeCoreOperation.SqrtRem)):
                 m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
             with m.Else():  # DivPipeCoreOperation.RSqrtRem
                 m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
             trial_compare_rhs_values.append(trial_compare_rhs)
 
-            pass_flag = Signal(name=f"pass_flag_{trial_bits}")
+            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
             m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
             pass_flags.append(pass_flag)
 
@@ -347,26 +369,26 @@ class DivPipeCoreCalculateStage(Elaboratable):
         # Assumes that pass_flag[0] is always set (since
         # compare_lhs >= compare_rhs is a pipeline invariant).
 
-        next_bits = Signal(log2_radix)
+        next_bits = Signal(log2_radix, reset_less=True)
         for i in range(log2_radix):
             bit_value = 1
             for j in range(0, radix, 1 << i):
                 bit_value ^= pass_flags[j]
             m.d.comb += next_bits.part(i, 1).eq(bit_value)
 
-        next_compare_rhs = Signal(radix, reset_less=True)
-        l = []
+        # XXX using a list to accumulate the bits and then using bool
+        # is IMPORTANT.  if done using |= it results in a chain of OR gates.
+        l = [] # next_compare_rhs
         for i in range(radix):
             next_flag = pass_flags[i + 1] if i + 1 < radix else 0
-            flag = Signal(reset_less=True)
-            test = Signal(reset_less=True)
-            # XXX TODO: check the width on this
-            m.d.comb += test.eq((pass_flags[i] & ~next_flag))
-            m.d.comb += flag.eq(Mux(test, trial_compare_rhs_values[i], 0))
-            l.append(flag)
-
-        m.d.comb += next_compare_rhs.eq(Cat(*l))
-        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs.bool())
+            selected = Signal(name=f"selected_{i}", reset_less=True)
+            m.d.comb += selected.eq(pass_flags[i] & ~next_flag)
+            l.append(Mux(selected, trial_compare_rhs_values[i], 0)
+
+        # concatenate the list of Mux results together and OR them using
+        # the bool operator.
+        m.d.comb += self.o.compare_rhs.eq(Cat(*l).bool())
+
         m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                   + ((self.i.divisor_radicand
                                                       * next_bits)