reduce next_bits by 1
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
index c41bb93ab00971ee713db61897b488e8ee230dc5..133216986d9d16ea76b966a044a075a8fc8f1d3d 100644 (file)
@@ -18,7 +18,8 @@ Formulas solved are:
 The remainder is the left-hand-side of the comparison minus the
 right-hand-side of the comparison in the above formulas.
 """
-from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat)
+from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Array)
+from nmigen.lib.coding import PriorityEncoder
 import enum
 
 
@@ -368,9 +369,13 @@ class DivPipeCoreCalculateStage(Elaboratable):
     def elaborate(self, platform):
         """ Elaborate into ``Module``. """
         m = Module()
+
+        # copy invariant inputs to outputs (for next stage)
         m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
         m.d.comb += self.o.operation.eq(self.i.operation)
         m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)
+
+        # constants
         log2_radix = self.core_config.log2_radix
         current_shift = self.core_config.bit_width
         current_shift -= self.stage_index * log2_radix
@@ -378,12 +383,15 @@ class DivPipeCoreCalculateStage(Elaboratable):
         assert log2_radix > 0
         current_shift -= log2_radix
         radix = 1 << log2_radix
+
+        # trials within this radix range.  carried out by Trial module,
+        # results stored in pass_flags.  pass_flags are unary priority.
         trial_compare_rhs_values = []
         pfl = []
         for trial_bits in range(radix):
-            t = Trial(self.core_config, trial_bits,
-                          current_shift, log2_radix)
+            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
             setattr(m.submodules, "trial%d" % trial_bits, t)
+
             m.d.comb += t.divisor_radicand.eq(self.i.divisor_radicand)
             m.d.comb += t.quotient_root.eq(self.i.quotient_root)
             m.d.comb += t.root_times_radicand.eq(self.i.root_times_radicand)
@@ -395,10 +403,11 @@ class DivPipeCoreCalculateStage(Elaboratable):
             pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
             m.d.comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
             pfl.append(pass_flag)
+
         pass_flags = Signal(radix, reset_less=True)
         m.d.comb += pass_flags.eq(Cat(*pfl))
 
-        # convert pass_flags to next_bits.
+        # convert pass_flags (unary priority) to next_bits (binary index)
         #
         # Assumes that for each set bit in pass_flag, all previous bits are
         # also set.
@@ -406,30 +415,19 @@ class DivPipeCoreCalculateStage(Elaboratable):
         # Assumes that pass_flag[0] is always set (since
         # compare_lhs >= compare_rhs is a pipeline invariant).
 
+        m.submodules.pe = pe = PriorityEncoder(radix)
         next_bits = Signal(log2_radix, reset_less=True)
-        l = []
-        for i in range(log2_radix):
-            bit_value = 1
-            for j in range(0, radix, 1 << i):
-                bit_value ^= pass_flags[j]
-            bv = Signal(reset_less=True)
-            m.d.comb += bv.eq(bit_value)
-            l.append(bv)
-        m.d.comb += next_bits.eq(Cat(*l))
-
-        # merge/select multi-bit trial_compare_rhs_values, to go
-        # into compare_rhs. XXX (only one of these will succeed?)
-        next_compare_rhs = 0
-        for i in range(radix):
-            next_flag = Signal(name=f"next_flag{i}", reset_less=True)
-            selected = Signal(name=f"selected_{i}", reset_less=True)
-            m.d.comb += next_flag.eq(~pass_flags[i + 1] if i + 1 < radix else 1)
-            m.d.comb += selected.eq(pass_flags[i] & next_flag)
-            next_compare_rhs |= Mux(selected,
-                                    trial_compare_rhs_values[i],
-                                    0)
-
-        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
+        m.d.comb += pe.i.eq(~pass_flags)
+        with m.If(~pe.n):
+            m.d.comb += next_bits.eq(pe.o-1)
+        with m.Else():
+            m.d.comb += next_bits.eq(radix-1)
+
+        # get the highest passing rhs trial (indexed by next_bits)
+        ta = Array(trial_compare_rhs_values)
+        m.d.comb += self.o.compare_rhs.eq(ta[next_bits])
+
+        # create outputs for next phase
         m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                   + ((self.i.divisor_radicand
                                                       * next_bits)