return m
+class Trial(Elaboratable):
+    """Compute the trial ``compare_rhs`` for one candidate ``trial_bits``.
+
+    ``operation`` selects which terms (UDivRem, SqrtRem, else RSqrtRem) are
+    added to ``compare_rhs``; the sum is driven onto ``trial_compare_rhs``.
+    ``trial_bits`` and ``current_shift`` are elaboration-time integers.
+    """
+
+    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
+        self.core_config = core_config
+        self.trial_bits = trial_bits
+        self.current_shift = current_shift
+        self.log2_radix = log2_radix
+        bw = core_config.bit_width
+        # combinatorial inputs, plus the trial_compare_rhs output
+        self.divisor_radicand = Signal(bw, reset_less=True)
+        self.quotient_root = Signal(bw, reset_less=True)
+        self.root_times_radicand = Signal(bw * 2, reset_less=True)
+        self.compare_rhs = Signal(bw * 3, reset_less=True)
+        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
+        self.operation = DP.create_signal(reset_less=True)
+
+    def elaborate(self, platform):
+        """Return a Module holding this trial's combinatorial logic."""
+        m = Module()
+
+        if self.trial_bits == 0:
+            # no point adding multiply by zero: every added term vanishes,
+            # so the result is compare_rhs whatever the operation is.
+            m.d.comb += self.trial_compare_rhs.eq(self.compare_rhs)
+            return m
+
+        dr = self.divisor_radicand
+        qr = self.quotient_root
+        rr = self.root_times_radicand
+        fract_width = self.core_config.fract_width
+        shift = self.current_shift
+
+        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
+        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
+                                    self.log2_radix * 2)
+
+        tblen = self.core_config.bit_width + self.log2_radix
+        tblen2 = self.core_config.bit_width + self.log2_radix * 2
+        dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
+        m.d.comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
+
+        # NOTE: "+" binds tighter than "<<", so each shifted term below is
+        # parenthesised before being added to compare_rhs.
+
+        # UDivRem
+        with m.If(self.operation == int(DP.UDivRem)):
+            dr_times_trial_bits = Signal(tblen, reset_less=True)
+            m.d.comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
+            div_rhs = self.compare_rhs \
+                + (dr_times_trial_bits << (fract_width + shift))
+            m.d.comb += self.trial_compare_rhs.eq(div_rhs)
+
+        # SqrtRem
+        with m.Elif(self.operation == int(DP.SqrtRem)):
+            qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
+            m.d.comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
+            sqrt_rhs = self.compare_rhs \
+                + (qr_times_trial_bits << (fract_width + shift + 1)) \
+                + (trial_bits_sqrd_sig << (fract_width + shift * 2))
+            m.d.comb += self.trial_compare_rhs.eq(sqrt_rhs)
+
+        # RSqrtRem (taken for any other operation value)
+        with m.Else():
+            rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
+            m.d.comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
+            rsqrt_rhs = self.compare_rhs \
+                + (rr_times_trial_bits << (shift + 1)) \
+                + (dr_times_trial_bits_sqrd << (shift * 2))
+            m.d.comb += self.trial_compare_rhs.eq(rsqrt_rhs)
+
+        return m
+
+
class DivPipeCoreCalculateStage(Elaboratable):
""" Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
trial_compare_rhs_values = []
pass_flags = []
for trial_bits in range(radix):
- trial_bits_sig = Const(trial_bits, log2_radix)
- trial_bits_sqrd_sig = Const(trial_bits * trial_bits,
- log2_radix * 2)
-
- dr_times_trial_bits = self.i.divisor_radicand * trial_bits_sig
- dr_times_trial_bits_sqrd = self.i.divisor_radicand \
- * trial_bits_sqrd_sig
- qr_times_trial_bits = self.i.quotient_root * trial_bits_sig
- rr_times_trial_bits = self.i.root_times_radicand * trial_bits_sig
-
- trial_compare_rhs = Signal.like(
- self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
- reset_less=True)
- m.d.comb += trial_compare_rhs.eq(self.i.compare_rhs)
-
- if trial_bits != 0: # no point adding multiply by zero
- # UDivRem
- with m.If(self.i.operation == int(DP.UDivRem)):
- div_rhs = self.i.compare_rhs
-
- div_term1 = dr_times_trial_bits
- div_term1_shift = self.core_config.fract_width
- div_term1_shift += current_shift
- div_rhs += div_term1 << div_term1_shift
-
- m.d.comb += trial_compare_rhs.eq(div_rhs)
-
- # SqrtRem
- with m.Elif(self.i.operation == int(DP.SqrtRem)):
- sqrt_rhs = self.i.compare_rhs
-
- sqrt_term1 = qr_times_trial_bits
- sqrt_term1_shift = self.core_config.fract_width
- sqrt_term1_shift += current_shift + 1
- sqrt_rhs += sqrt_term1 << sqrt_term1_shift
- sqrt_term2 = trial_bits_sqrd_sig
- sqrt_term2_shift = self.core_config.fract_width
- sqrt_term2_shift += current_shift * 2
- sqrt_rhs += sqrt_term2 << sqrt_term2_shift
-
- m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
-
- # RSqrtRem
- with m.Else():
- rsqrt_rhs = self.i.compare_rhs
-
- rsqrt_term1 = rr_times_trial_bits
- rsqrt_term1_shift = current_shift + 1
- rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
- rsqrt_term2 = dr_times_trial_bits_sqrd
- rsqrt_term2_shift = current_shift * 2
- rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
-
- m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
-
- trial_compare_rhs_values.append(trial_compare_rhs)
+ t = Trial(self.core_config, trial_bits,
+ current_shift, log2_radix)
+ setattr(m.submodules, "trial%d" % trial_bits, t)
+ m.d.comb += t.divisor_radicand.eq(self.i.divisor_radicand)
+ m.d.comb += t.quotient_root.eq(self.i.quotient_root)
+ m.d.comb += t.root_times_radicand.eq(self.i.root_times_radicand)
+ m.d.comb += t.compare_rhs.eq(self.i.compare_rhs)
+ m.d.comb += t.operation.eq(self.i.operation)
+
+ trial_compare_rhs_values.append(t.trial_compare_rhs)
pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
- m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
+ m.d.comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
pass_flags.append(pass_flag)
# convert pass_flags to next_bits.