From 4a74b6d3bb4dca7d6c855a9b9348d9d40dec788b Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 10 Jul 2019 01:01:23 -0700 Subject: [PATCH] DivPipeCore tests pass; still need to add more tests --- src/ieee754/div_rem_sqrt_rsqrt/core.py | 69 ++++++++++++--------- src/ieee754/div_rem_sqrt_rsqrt/test_core.py | 58 +++++++---------- 2 files changed, 62 insertions(+), 65 deletions(-) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/core.py b/src/ieee754/div_rem_sqrt_rsqrt/core.py index 141deb73..e6a0b9b9 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/core.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/core.py @@ -69,7 +69,7 @@ class DivPipeCoreOperation(enum.Enum): def create_signal(cls, *, src_loc_at=0, **kwargs): """ Create a signal that can contain a ``DivPipeCoreOperation``. """ return Signal(min=min(map(int, cls)), - max=max(map(int, cls)), + max=max(map(int, cls)) + 2, src_loc_at=(src_loc_at + 1), decoder=lambda v: str(cls(v)), **kwargs) @@ -305,31 +305,45 @@ class DivPipeCoreCalculateStage(Elaboratable): trial_compare_rhs_values = [] pass_flags = [] for trial_bits in range(radix): - tb = trial_bits << current_shift - tb_width = log2_radix + current_shift - shifted_trial_bits = Const(tb, tb_width) - shifted_trial_bits2 = Const(tb*2, tb_width+1) - shifted_trial_bits_sqrd = Const(tb * tb, tb_width * 2) + trial_bits_sig = Const(trial_bits, log2_radix) + trial_bits_sqrd_sig = Const(trial_bits * trial_bits, + log2_radix * 2) + + dr_times_trial_bits = self.i.divisor_radicand * trial_bits_sig + dr_times_trial_bits_sqrd = self.i.divisor_radicand \ + * trial_bits_sqrd_sig + qr_times_trial_bits = self.i.quotient_root * trial_bits_sig + rr_times_trial_bits = self.i.root_times_radicand * trial_bits_sig # UDivRem div_rhs = self.i.compare_rhs - if tb != 0: # no point adding stuff that's multiplied by zero - div_factor1 = self.i.divisor_radicand * shifted_trial_bits2 - div_rhs += div_factor1 << self.core_config.fract_width + if trial_bits != 0: # no point adding stuff that's multiplied by zero + div_term1 = dr_times_trial_bits + div_term1_shift = self.core_config.fract_width + div_term1_shift += current_shift + div_rhs += div_term1 << div_term1_shift # SqrtRem sqrt_rhs = self.i.compare_rhs - if tb != 0: # no point adding stuff that's multiplied by zero - sqrt_factor1 = self.i.quotient_root * shifted_trial_bits2 - sqrt_rhs += sqrt_factor1 << self.core_config.fract_width - sqrt_factor2 = shifted_trial_bits_sqrd - sqrt_rhs += sqrt_factor2 << self.core_config.fract_width + if trial_bits != 0: # no point adding stuff that's multiplied by zero + sqrt_term1 = qr_times_trial_bits + sqrt_term1_shift = self.core_config.fract_width + sqrt_term1_shift += current_shift + 1 + sqrt_rhs += sqrt_term1 << sqrt_term1_shift + sqrt_term2 = trial_bits_sqrd_sig + sqrt_term2_shift = self.core_config.fract_width + sqrt_term2_shift += current_shift * 2 + sqrt_rhs += sqrt_term2 << sqrt_term2_shift # RSqrtRem rsqrt_rhs = self.i.compare_rhs - if tb != 0: # no point adding stuff that's multiplied by zero - rsqrt_rhs += self.i.root_times_radicand * shifted_trial_bits2 - rsqrt_rhs += self.i.divisor_radicand * shifted_trial_bits_sqrd + if trial_bits != 0: # no point adding stuff that's multiplied by zero + rsqrt_term1 = rr_times_trial_bits + rsqrt_term1_shift = current_shift + 1 + rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift + rsqrt_term2 = dr_times_trial_bits_sqrd + rsqrt_term2_shift = current_shift * 2 + rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift trial_compare_rhs = Signal.like( self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}", @@ -362,19 +376,16 @@ class DivPipeCoreCalculateStage(Elaboratable): bit_value ^= pass_flags[j] m.d.comb += next_bits.part(i, 1).eq(bit_value) - next_compare_rhs = Signal(radix, reset_less=True) - l = [] + next_compare_rhs = 0 for i in range(radix): - next_flag = pass_flags[i + 1] if (i + 1 < radix) else Const(0) - flag = Signal(reset_less=True, name=f"flag{i}") - test = Signal(reset_less=True, name=f"test{i}") - # XXX TODO: check the width on this - m.d.comb += test.eq((pass_flags[i] & ~next_flag)) - m.d.comb += flag.eq(Mux(test, trial_compare_rhs_values[i], 0)) - l.append(flag) - - m.d.comb += next_compare_rhs.eq(Cat(*l)) - m.d.comb += self.o.compare_rhs.eq(next_compare_rhs.bool()) + next_flag = pass_flags[i + 1] if i + 1 < radix else 0 + selected = Signal(name=f"selected_{i}", reset_less=True) + m.d.comb += selected.eq(pass_flags[i] & ~next_flag) + next_compare_rhs |= Mux(selected, + trial_compare_rhs_values[i], + 0) + + m.d.comb += self.o.compare_rhs.eq(next_compare_rhs) m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand + ((self.i.divisor_radicand * next_bits) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/test_core.py b/src/ieee754/div_rem_sqrt_rsqrt/test_core.py index 59beadde..a9a3d29f 100755 --- a/src/ieee754/div_rem_sqrt_rsqrt/test_core.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/test_core.py @@ -75,38 +75,18 @@ class TestCaseData: def generate_test_case(core_config, dividend, divisor_radicand, alg_op): bit_width = core_config.bit_width fract_width = core_config.fract_width - if alg_op is Operation.UDivRem: - if divisor_radicand == 0: - return - quotient_root, remainder = div_rem(dividend, - divisor_radicand, - bit_width * 3, - False) - remainder <<= fract_width - elif alg_op is Operation.SqrtRem: - root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand, - fract_width, - bit_width, - False)) - quotient_root = root_remainder.root.bits - remainder = root_remainder.remainder.bits << fract_width - else: - assert alg_op is Operation.RSqrtRem - if divisor_radicand == 0: - return - root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand, - fract_width, - bit_width, - False)) - quotient_root = root_remainder.root.bits - remainder = root_remainder.remainder.bits - if quotient_root >= (1 << bit_width): - return + obj = FixedUDivRemSqrtRSqrt(dividend, + divisor_radicand, + alg_op, + bit_width, + fract_width, + core_config.log2_radix) + obj.calculate() yield TestCaseData(dividend, divisor_radicand, alg_op, - quotient_root, - remainder, + obj.quotient_root, + obj.remainder, core_config) @@ -145,7 +125,7 @@ def get_test_cases(core_config, class DivPipeCoreTestPipeline(Elaboratable): - def __init__(self, core_config, sync=True): + def __init__(self, core_config, sync): self.setup_stage = DivPipeCoreSetupStage(core_config) self.calculate_stages = [ DivPipeCoreCalculateStage(core_config, stage_index) @@ -202,22 +182,22 @@ class TestDivPipeCore(unittest.TestCase): base_name += f"_fract_width_{core_config.fract_width}" base_name += f"_radix_{1 << core_config.log2_radix}" with self.subTest(part="synthesize"): - dut = DivPipeCoreTestPipeline(core_config) + dut = DivPipeCoreTestPipeline(core_config, sync) vl = rtlil.convert(dut, ports=[*dut.i, *dut.o]) with open(f"{base_name}.il", "w") as f: f.write(vl) - dut = DivPipeCoreTestPipeline(core_config) + dut = DivPipeCoreTestPipeline(core_config, sync) with Simulator(dut, vcd_file=open(f"{base_name}.vcd", "w"), gtkw_file=open(f"{base_name}.gtkw", "w"), traces=[*dut.traces()]) as sim: def generate_process(): for test_case in gen_test_cases(): + yield Tick() yield dut.i.dividend.eq(test_case.dividend) yield dut.i.divisor_radicand.eq(test_case.divisor_radicand) yield dut.i.operation.eq(int(test_case.core_op)) - yield Delay(1e-6) - yield Tick() + yield Delay(0.9e-6) def check_process(): # sync with generator @@ -229,14 +209,14 @@ class TestDivPipeCore(unittest.TestCase): # now synched with generator for test_case in gen_test_cases(): - yield Delay(1e-6) + yield Tick() + yield Delay(0.9e-6) quotient_root = (yield dut.o.quotient_root) remainder = (yield dut.o.remainder) with self.subTest(test_case=str(test_case)): self.assertEqual(quotient_root, test_case.quotient_root) self.assertEqual(remainder, test_case.remainder) - yield Tick() sim.add_clock(2e-6) sim.add_sync_process(generate_process) sim.add_sync_process(check_process) @@ -250,6 +230,12 @@ class TestDivPipeCore(unittest.TestCase): *range(1 << 8, 1 << 12, 1 << 4)], sync=False) + def test_bit_width_2_fract_width_1_radix_2(self): + self.handle_case(DivPipeCoreConfig(bit_width=2, + fract_width=1, + log2_radix=1), + sync=False) + # FIXME: add more test_* functions -- 2.30.2