From 6325dfc3a5045af7153ebd0a09fcd2ecff24f23a Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Tue, 30 Jul 2019 11:29:46 +0100 Subject: [PATCH] use switch/case rather than if/elif/elif --- src/ieee754/div_rem_sqrt_rsqrt/core.py | 102 +++++++++++++------------ 1 file changed, 52 insertions(+), 50 deletions(-) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/core.py b/src/ieee754/div_rem_sqrt_rsqrt/core.py index a070e65d..507caf28 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/core.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/core.py @@ -242,12 +242,13 @@ class DivPipeCoreSetupStage(Elaboratable): lhs = Signal(self.core_config.bit_width * 3, reset_less=True) fw = self.core_config.fract_width - with m.If(self.i.operation == int(DP.UDivRem)): - comb += lhs.eq(self.i.dividend << fw) - with m.Elif(self.i.operation == int(DP.SqrtRem)): - comb += lhs.eq(self.i.divisor_radicand << (fw * 2)) - with m.Elif(self.i.operation == int(DP.RSqrtRem)): - comb += lhs.eq(1 << (fw * 3)) + with m.Switch(self.i.operation): + with m.Case(int(DP.UDivRem)): + comb += lhs.eq(self.i.dividend << fw) + with m.Case(int(DP.SqrtRem)): + comb += lhs.eq(self.i.divisor_radicand << (fw * 2)) + with m.Case(int(DP.RSqrtRem)): + comb += lhs.eq(1 << (fw * 3)) comb += self.o.compare_lhs.eq(lhs) comb += self.o.compare_rhs.eq(0) @@ -288,50 +289,51 @@ class Trial(Elaboratable): dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True) comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig) - # UDivRem - with m.If(self.operation == int(DP.UDivRem)): - dr_times_trial_bits = Signal(tblen, reset_less=True) - comb += dr_times_trial_bits.eq(dr * trial_bits_sig) - div_rhs = self.compare_rhs - - div_term1 = dr_times_trial_bits - div_term1_shift = self.core_config.fract_width - div_term1_shift += self.current_shift - div_rhs += div_term1 << div_term1_shift - - comb += self.trial_compare_rhs.eq(div_rhs) - - # SqrtRem - with m.Elif(self.operation == int(DP.SqrtRem)): - qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True) - comb += qr_times_trial_bits.eq(qr * trial_bits_sig) - sqrt_rhs = self.compare_rhs - - sqrt_term1 = qr_times_trial_bits - sqrt_term1_shift = self.core_config.fract_width - sqrt_term1_shift += self.current_shift + 1 - sqrt_rhs += sqrt_term1 << sqrt_term1_shift - sqrt_term2 = trial_bits_sqrd_sig - sqrt_term2_shift = self.core_config.fract_width - sqrt_term2_shift += self.current_shift * 2 - sqrt_rhs += sqrt_term2 << sqrt_term2_shift - - comb += self.trial_compare_rhs.eq(sqrt_rhs) - - # RSqrtRem - with m.Elif(self.operation == int(DP.RSqrtRem)): - rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True) - comb += rr_times_trial_bits.eq(rr * trial_bits_sig) - rsqrt_rhs = self.compare_rhs - - rsqrt_term1 = rr_times_trial_bits - rsqrt_term1_shift = self.current_shift + 1 - rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift - rsqrt_term2 = dr_times_trial_bits_sqrd - rsqrt_term2_shift = self.current_shift * 2 - rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift - - comb += self.trial_compare_rhs.eq(rsqrt_rhs) + with m.Switch(self.operation): + # UDivRem + with m.Case(int(DP.UDivRem)): + dr_times_trial_bits = Signal(tblen, reset_less=True) + comb += dr_times_trial_bits.eq(dr * trial_bits_sig) + div_rhs = self.compare_rhs + + div_term1 = dr_times_trial_bits + div_term1_shift = self.core_config.fract_width + div_term1_shift += self.current_shift + div_rhs += div_term1 << div_term1_shift + + comb += self.trial_compare_rhs.eq(div_rhs) + + # SqrtRem + with m.Case(int(DP.SqrtRem)): + qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True) + comb += qr_times_trial_bits.eq(qr * trial_bits_sig) + sqrt_rhs = self.compare_rhs + + sqrt_term1 = qr_times_trial_bits + sqrt_term1_shift = self.core_config.fract_width + sqrt_term1_shift += self.current_shift + 1 + sqrt_rhs += sqrt_term1 << sqrt_term1_shift + sqrt_term2 = trial_bits_sqrd_sig + sqrt_term2_shift = self.core_config.fract_width + sqrt_term2_shift += self.current_shift * 2 + sqrt_rhs += sqrt_term2 << sqrt_term2_shift + + comb += self.trial_compare_rhs.eq(sqrt_rhs) + + # RSqrtRem + with m.Case(int(DP.RSqrtRem)): + rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True) + comb += rr_times_trial_bits.eq(rr * trial_bits_sig) + rsqrt_rhs = self.compare_rhs + + rsqrt_term1 = rr_times_trial_bits + rsqrt_term1_shift = self.current_shift + 1 + rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift + rsqrt_term2 = dr_times_trial_bits_sqrd + rsqrt_term2_shift = self.current_shift * 2 + rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift + + comb += self.trial_compare_rhs.eq(rsqrt_rhs) return m -- 2.30.2