DivPipeCore tests pass; still need to add more tests
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 10 Jul 2019 08:01:23 +0000 (01:01 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 10 Jul 2019 08:01:23 +0000 (01:01 -0700)
src/ieee754/div_rem_sqrt_rsqrt/core.py
src/ieee754/div_rem_sqrt_rsqrt/test_core.py

index 141deb7365456d727952588c6e8c848c91652625..e6a0b9b9d848d93cbef2b35b371befb2d946d7a2 100644 (file)
@@ -69,7 +69,7 @@ class DivPipeCoreOperation(enum.Enum):
     def create_signal(cls, *, src_loc_at=0, **kwargs):
         """ Create a signal that can contain a ``DivPipeCoreOperation``. """
         return Signal(min=min(map(int, cls)),
-                      max=max(map(int, cls)),
+                      max=max(map(int, cls)) + 2,
                       src_loc_at=(src_loc_at + 1),
                       decoder=lambda v: str(cls(v)),
                       **kwargs)
@@ -305,31 +305,45 @@ class DivPipeCoreCalculateStage(Elaboratable):
         trial_compare_rhs_values = []
         pass_flags = []
         for trial_bits in range(radix):
-            tb = trial_bits << current_shift
-            tb_width = log2_radix + current_shift
-            shifted_trial_bits = Const(tb, tb_width)
-            shifted_trial_bits2 = Const(tb*2, tb_width+1)
-            shifted_trial_bits_sqrd = Const(tb * tb, tb_width * 2)
+            trial_bits_sig = Const(trial_bits, log2_radix)
+            trial_bits_sqrd_sig = Const(trial_bits * trial_bits,
+                                        log2_radix * 2)
+
+            dr_times_trial_bits = self.i.divisor_radicand * trial_bits_sig
+            dr_times_trial_bits_sqrd = self.i.divisor_radicand \
+                * trial_bits_sqrd_sig
+            qr_times_trial_bits = self.i.quotient_root * trial_bits_sig
+            rr_times_trial_bits = self.i.root_times_radicand * trial_bits_sig
 
             # UDivRem
             div_rhs = self.i.compare_rhs
-            if tb != 0:  # no point adding stuff that's multiplied by zero
-                div_factor1 = self.i.divisor_radicand * shifted_trial_bits2
-                div_rhs += div_factor1 << self.core_config.fract_width
+            if trial_bits != 0:  # no point adding stuff that's multiplied by zero
+                div_term1 = dr_times_trial_bits
+                div_term1_shift = self.core_config.fract_width
+                div_term1_shift += current_shift
+                div_rhs += div_term1 << div_term1_shift
 
             # SqrtRem
             sqrt_rhs = self.i.compare_rhs
-            if tb != 0:  # no point adding stuff that's multiplied by zero
-                sqrt_factor1 = self.i.quotient_root * shifted_trial_bits2
-                sqrt_rhs += sqrt_factor1 << self.core_config.fract_width
-                sqrt_factor2 = shifted_trial_bits_sqrd
-                sqrt_rhs += sqrt_factor2 << self.core_config.fract_width
+            if trial_bits != 0:  # no point adding stuff that's multiplied by zero
+                sqrt_term1 = qr_times_trial_bits
+                sqrt_term1_shift = self.core_config.fract_width
+                sqrt_term1_shift += current_shift + 1
+                sqrt_rhs += sqrt_term1 << sqrt_term1_shift
+                sqrt_term2 = trial_bits_sqrd_sig
+                sqrt_term2_shift = self.core_config.fract_width
+                sqrt_term2_shift += current_shift * 2
+                sqrt_rhs += sqrt_term2 << sqrt_term2_shift
 
             # RSqrtRem
             rsqrt_rhs = self.i.compare_rhs
-            if tb != 0:  # no point adding stuff that's multiplied by zero
-                rsqrt_rhs += self.i.root_times_radicand * shifted_trial_bits2
-                rsqrt_rhs += self.i.divisor_radicand * shifted_trial_bits_sqrd
+            if trial_bits != 0:  # no point adding stuff that's multiplied by zero
+                rsqrt_term1 = rr_times_trial_bits
+                rsqrt_term1_shift = current_shift + 1
+                rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
+                rsqrt_term2 = dr_times_trial_bits_sqrd
+                rsqrt_term2_shift = current_shift * 2
+                rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
 
             trial_compare_rhs = Signal.like(
                 self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
@@ -362,19 +376,16 @@ class DivPipeCoreCalculateStage(Elaboratable):
                 bit_value ^= pass_flags[j]
             m.d.comb += next_bits.part(i, 1).eq(bit_value)
 
-        next_compare_rhs = Signal(radix, reset_less=True)
-        l = []
+        next_compare_rhs = 0
         for i in range(radix):
-            next_flag = pass_flags[i + 1] if (i + 1 < radix) else Const(0)
-            flag = Signal(reset_less=True, name=f"flag{i}")
-            test = Signal(reset_less=True, name=f"test{i}")
-            # XXX TODO: check the width on this
-            m.d.comb += test.eq((pass_flags[i] & ~next_flag))
-            m.d.comb += flag.eq(Mux(test, trial_compare_rhs_values[i], 0))
-            l.append(flag)
-
-        m.d.comb += next_compare_rhs.eq(Cat(*l))
-        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs.bool())
+            next_flag = pass_flags[i + 1] if i + 1 < radix else 0
+            selected = Signal(name=f"selected_{i}", reset_less=True)
+            m.d.comb += selected.eq(pass_flags[i] & ~next_flag)
+            next_compare_rhs |= Mux(selected,
+                                    trial_compare_rhs_values[i],
+                                    0)
+
+        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
         m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                   + ((self.i.divisor_radicand
                                                       * next_bits)
index 59beaddebb10e37688137d132ad25b24726219de..a9a3d29fed42cfe5d37665b1346dbf1dfc4e8925 100755 (executable)
@@ -75,38 +75,18 @@ class TestCaseData:
 def generate_test_case(core_config, dividend, divisor_radicand, alg_op):
     bit_width = core_config.bit_width
     fract_width = core_config.fract_width
-    if alg_op is Operation.UDivRem:
-        if divisor_radicand == 0:
-            return
-        quotient_root, remainder = div_rem(dividend,
-                                           divisor_radicand,
-                                           bit_width * 3,
-                                           False)
-        remainder <<= fract_width
-    elif alg_op is Operation.SqrtRem:
-        root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand,
-                                                    fract_width,
-                                                    bit_width,
-                                                    False))
-        quotient_root = root_remainder.root.bits
-        remainder = root_remainder.remainder.bits << fract_width
-    else:
-        assert alg_op is Operation.RSqrtRem
-        if divisor_radicand == 0:
-            return
-        root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand,
-                                                     fract_width,
-                                                     bit_width,
-                                                     False))
-        quotient_root = root_remainder.root.bits
-        remainder = root_remainder.remainder.bits
-    if quotient_root >= (1 << bit_width):
-        return
+    obj = FixedUDivRemSqrtRSqrt(dividend,
+                                divisor_radicand,
+                                alg_op,
+                                bit_width,
+                                fract_width,
+                                core_config.log2_radix)
+    obj.calculate()
     yield TestCaseData(dividend,
                        divisor_radicand,
                        alg_op,
-                       quotient_root,
-                       remainder,
+                       obj.quotient_root,
+                       obj.remainder,
                        core_config)
 
 
@@ -145,7 +125,7 @@ def get_test_cases(core_config,
 
 
 class DivPipeCoreTestPipeline(Elaboratable):
-    def __init__(self, core_config, sync=True):
+    def __init__(self, core_config, sync):
         self.setup_stage = DivPipeCoreSetupStage(core_config)
         self.calculate_stages = [
             DivPipeCoreCalculateStage(core_config, stage_index)
@@ -202,22 +182,22 @@ class TestDivPipeCore(unittest.TestCase):
         base_name += f"_fract_width_{core_config.fract_width}"
         base_name += f"_radix_{1 << core_config.log2_radix}"
         with self.subTest(part="synthesize"):
-            dut = DivPipeCoreTestPipeline(core_config)
+            dut = DivPipeCoreTestPipeline(core_config, sync)
             vl = rtlil.convert(dut, ports=[*dut.i, *dut.o])
             with open(f"{base_name}.il", "w") as f:
                 f.write(vl)
-        dut = DivPipeCoreTestPipeline(core_config)
+        dut = DivPipeCoreTestPipeline(core_config, sync)
         with Simulator(dut,
                        vcd_file=open(f"{base_name}.vcd", "w"),
                        gtkw_file=open(f"{base_name}.gtkw", "w"),
                        traces=[*dut.traces()]) as sim:
             def generate_process():
                 for test_case in gen_test_cases():
+                    yield Tick()
                     yield dut.i.dividend.eq(test_case.dividend)
                     yield dut.i.divisor_radicand.eq(test_case.divisor_radicand)
                     yield dut.i.operation.eq(int(test_case.core_op))
-                    yield Delay(1e-6)
-                    yield Tick()
+                    yield Delay(0.9e-6)
 
             def check_process():
                 # sync with generator
@@ -229,14 +209,14 @@ class TestDivPipeCore(unittest.TestCase):
 
                 # now synched with generator
                 for test_case in gen_test_cases():
-                    yield Delay(1e-6)
+                    yield Tick()
+                    yield Delay(0.9e-6)
                     quotient_root = (yield dut.o.quotient_root)
                     remainder = (yield dut.o.remainder)
                     with self.subTest(test_case=str(test_case)):
                         self.assertEqual(quotient_root,
                                          test_case.quotient_root)
                         self.assertEqual(remainder, test_case.remainder)
-                    yield Tick()
             sim.add_clock(2e-6)
             sim.add_sync_process(generate_process)
             sim.add_sync_process(check_process)
@@ -250,6 +230,12 @@ class TestDivPipeCore(unittest.TestCase):
                                     *range(1 << 8, 1 << 12, 1 << 4)],
                          sync=False)
 
+    def test_bit_width_2_fract_width_1_radix_2(self):
+        self.handle_case(DivPipeCoreConfig(bit_width=2,
+                                           fract_width=1,
+                                           log2_radix=1),
+                         sync=False)
+
     # FIXME: add more test_* functions