From 269f8b1aaefb682e2d0c57729b357cd30919e612 Mon Sep 17 00:00:00 2001
From: Jacob Lifshay
Date: Fri, 5 Jul 2019 05:01:40 -0700
Subject: [PATCH] add rest of DivPipeCore

---
 src/ieee754/div_rem_sqrt_rsqrt/core.py | 216 +++++++++++++++++++++++--
 1 file changed, 203 insertions(+), 13 deletions(-)

diff --git a/src/ieee754/div_rem_sqrt_rsqrt/core.py b/src/ieee754/div_rem_sqrt_rsqrt/core.py
index dd1fdf91..b52c8948 100644
--- a/src/ieee754/div_rem_sqrt_rsqrt/core.py
+++ b/src/ieee754/div_rem_sqrt_rsqrt/core.py
@@ -18,7 +18,7 @@ Formulas solved are:
 The remainder is the left-hand-side of the comparison minus the
 right-hand-side of the comparison in the above formulas.
 """
-from nmigen import (Elaboratable, Module, Signal)
+from nmigen import (Elaboratable, Module, Signal, Const, Mux)
 import enum
 
 # TODO
@@ -47,6 +47,11 @@ class DivPipeCoreConfig:
         return f"DivPipeCoreConfig({self.bit_width}, " \
             + f"{self.fract_width}, {self.log2_radix})"
 
+    @property
+    def num_calculate_stages(self):
+        """ Get the number of ``DivPipeCoreCalculateStage`` needed. """
+        return (self.bit_width + self.log2_radix - 1) // self.log2_radix
+
 
 class DivPipeCoreOperation(enum.IntEnum):
     """ Operation for ``DivPipeCore``.
@@ -94,20 +99,19 @@ class DivPipeCoreInputData:
         # FIXME: this goes into (is replaced by) self.ctx.op
         self.operation = DivPipeCoreOperation.create_signal(reset_less=True)
 
-        return # TODO: needs a width argument and a pspec
+        return  # TODO: needs a width argument and a pspec
         self.z = FPNumBaseRecord(width, False)
         self.out_do_z = Signal(reset_less=True)
         self.oz = Signal(width, reset_less=True)
 
-        self.ctx = FPPipeContext(width, pspec) # context: muxid, operator etc.
+        self.ctx = FPPipeContext(width, pspec)  # context: muxid, operator etc.
         self.muxid = self.ctx.muxid             # annoying. complicated.
 
-
     def __iter__(self):
         """ Get member signals. """
         yield self.dividend
         yield self.divisor_radicand
-        yield self.operation # FIXME: delete. already covered by self.ctx
+        yield self.operation  # FIXME: delete. already covered by self.ctx
         return
         yield self.z
         yield self.out_do_z
@@ -118,13 +122,12 @@ class DivPipeCoreInputData:
         """ Assign member signals. """
         return [self.dividend.eq(rhs.dividend),
                 self.divisor_radicand.eq(rhs.divisor_radicand),
-                self.operation.eq(rhs.operation)] # FIXME: delete.
+                self.operation.eq(rhs.operation)]  # FIXME: delete.
         # TODO: and these
         return [self.out_do_z.eq(i.out_do_z), self.oz.eq(i.oz),
                 self.ctx.eq(i.ctx)]
 
 
-
 class DivPipeCoreInterstageData:
     """ interstage data type for ``DivPipeCore``.
 
@@ -161,18 +164,18 @@ class DivPipeCoreInterstageData:
                                           reset_less=True)
         self.compare_lhs = Signal(core_config.bit_width * 3, reset_less=True)
         self.compare_rhs = Signal(core_config.bit_width * 3, reset_less=True)
-        return # TODO: needs a width argument and a pspec
+        return  # TODO: needs a width argument and a pspec
         self.z = FPNumBaseRecord(width, False)
         self.out_do_z = Signal(reset_less=True)
         self.oz = Signal(width, reset_less=True)
 
-        self.ctx = FPPipeContext(width, pspec) # context: muxid, operator etc.
+        self.ctx = FPPipeContext(width, pspec)  # context: muxid, operator etc.
         self.muxid = self.ctx.muxid             # annoying. complicated.
 
     def __iter__(self):
         """ Get member signals. """
         yield self.divisor_radicand
-        yield self.operation # XXX FIXME: delete. already in self.ctx.op
+        yield self.operation  # XXX FIXME: delete. already in self.ctx.op
         yield self.quotient_root
         yield self.root_times_radicand
         yield self.compare_lhs
@@ -186,7 +189,7 @@ class DivPipeCoreInterstageData:
     def eq(self, rhs):
         """ Assign member signals. """
         return [self.divisor_radicand.eq(rhs.divisor_radicand),
-                self.operation.eq(rhs.operation), # FIXME: delete.
+                self.operation.eq(rhs.operation),  # FIXME: delete.
                 self.quotient_root.eq(rhs.quotient_root),
                 self.root_times_radicand.eq(rhs.root_times_radicand),
                 self.compare_lhs.eq(rhs.compare_lhs),
@@ -196,10 +199,40 @@ class DivPipeCoreInterstageData:
                 self.ctx.eq(i.ctx)]
 
 
-class DivPipeCoreSetupStage(Elaboratable):
-    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.
+class DivPipeCoreOutputData:
+    """ output data type for ``DivPipeCore``.
+
+    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
+        configuration to be used.
+    :attribute quotient_root: the quotient or root part of the result of the
+        operation. Signal with a bit-width of ``core_config.bit_width`` and a
+        fract-width of ``core_config.fract_width`` bits.
+    :attribute remainder: the remainder part of the result of the operation.
+        Signal with a bit-width of ``core_config.bit_width * 3`` and a
+        fract-width of ``core_config.fract_width * 3`` bits.
     """
 
+    def __init__(self, core_config):
+        """ Create a ``DivPipeCoreOutputData`` instance. """
+        self.core_config = core_config
+        self.quotient_root = Signal(core_config.bit_width, reset_less=True)
+        self.remainder = Signal(core_config.bit_width * 3, reset_less=True)
+
+    def __iter__(self):
+        """ Get member signals. """
+        yield self.quotient_root
+        yield self.remainder
+        return
+
+    def eq(self, rhs):
+        """ Assign member signals. """
+        return [self.quotient_root.eq(rhs.quotient_root),
+                self.remainder.eq(rhs.remainder)]
+
+
+class DivPipeCoreSetupStage(Elaboratable):
+    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
+
     def __init__(self, core_config):
         """ Create a ``DivPipeCoreSetupStage`` instance."""
         self.core_config = core_config
@@ -251,3 +284,160 @@ class DivPipeCoreSetupStage(Elaboratable):
         m.d.comb += self.o.out_do_z.eq(self.i.out_do_z)
         m.d.comb += self.o.ctx.eq(self.i.ctx)
 
+
+class DivPipeCoreCalculateStage(Elaboratable):
+    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.
+
+    Each stage decides up to ``core_config.log2_radix`` further bits of
+    ``quotient_root`` by comparing ``compare_lhs`` against trial values.
+    """
+
+    def __init__(self, core_config, stage_index):
+        """ Create a ``DivPipeCoreCalculateStage`` instance. """
+        self.core_config = core_config
+        assert stage_index in range(core_config.num_calculate_stages)
+        self.stage_index = stage_index
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        """ Get the input spec for this pipeline stage. """
+        return DivPipeCoreInterstageData(self.core_config)
+
+    def ospec(self):
+        """ Get the output spec for this pipeline stage. """
+        return DivPipeCoreInterstageData(self.core_config)
+
+    def setup(self, m, i):
+        """ Pipeline stage setup. """
+        setattr(m.submodules,
+                f"div_pipe_core_calculate_{self.stage_index}",
+                self)
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        """ Pipeline stage process. """
+        return self.o
+
+    def elaborate(self, platform):
+        """ Elaborate into ``Module``. """
+        m = Module()
+        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
+        m.d.comb += self.o.operation.eq(self.i.operation)
+        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)
+        # the last stage may handle fewer than ``log2_radix`` bits;
+        # ``current_shift`` ends as the position of this stage's lowest bit.
+        log2_radix = self.core_config.log2_radix
+        current_shift = self.core_config.bit_width
+        current_shift -= self.stage_index * log2_radix
+        log2_radix = min(log2_radix, current_shift)
+        assert log2_radix > 0
+        current_shift -= log2_radix
+        radix = 1 << log2_radix
+        trial_compare_rhs_values = []
+        pass_flags = []
+        for trial_bits in range(radix):
+            shifted_trial_bits = Const(trial_bits, log2_radix) << current_shift
+            shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits
+
+            # UDivRem
+            div_rhs = self.i.compare_rhs
+            div_factor1 = self.i.divisor_radicand * shifted_trial_bits
+            div_rhs += div_factor1 << self.core_config.fract_width
+
+            # SqrtRem
+            sqrt_rhs = self.i.compare_rhs
+            sqrt_factor1 = self.i.quotient_root * (shifted_trial_bits << 1)
+            sqrt_rhs += sqrt_factor1 << self.core_config.fract_width
+            sqrt_factor2 = shifted_trial_bits_sqrd
+            sqrt_rhs += sqrt_factor2 << self.core_config.fract_width
+
+            # RSqrtRem
+            rsqrt_rhs = self.i.compare_rhs
+            rsqrt_rhs += self.i.root_times_radicand * (shifted_trial_bits << 1)
+            rsqrt_rhs += self.i.divisor_radicand * shifted_trial_bits_sqrd
+
+            trial_compare_rhs = self.o.compare_rhs.like(
+                name=f"trial_compare_rhs_{trial_bits}")
+
+            with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
+                m.d.comb += trial_compare_rhs.eq(div_rhs)
+            with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
+                m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
+            with m.Else():  # DivPipeCoreOperation.RSqrtRem
+                m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
+            trial_compare_rhs_values.append(trial_compare_rhs)
+
+            pass_flag = Signal(name=f"pass_flag_{trial_bits}")
+            m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
+            pass_flags.append(pass_flag)
+
+        # convert pass_flags to next_bits.
+        #
+        # Assumes that for each set bit in pass_flag, all previous bits are
+        # also set.
+        #
+        # Assumes that pass_flag[0] is always set (since
+        # compare_lhs >= compare_rhs is a pipeline invariant).
+
+        next_bits = Signal(log2_radix)
+        for i in range(log2_radix):
+            bit_value = 1
+            for j in range(0, radix, 1 << i):
+                bit_value ^= pass_flags[j]
+            m.d.comb += next_bits.part(i, 1).eq(bit_value)
+
+        # keep the trial value of the last passing trial, i.e. the one whose
+        # own flag is set while the following flag is clear.
+        next_compare_rhs = 0
+        for i in range(radix):
+            next_flag = pass_flags[i + 1] if i + 1 < radix else 0
+            next_compare_rhs |= Mux(pass_flags[i] & ~next_flag,
+                                    trial_compare_rhs_values[i],
+                                    0)
+
+        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
+        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
+                                                  + ((self.i.divisor_radicand
+                                                      * next_bits)
+                                                     << current_shift))
+        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
+                                            | (next_bits << current_shift))
+        return m
+
+
+class DivPipeCoreFinalStage(Elaboratable):
+    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
+
+    def __init__(self, core_config):
+        """ Create a ``DivPipeCoreFinalStage`` instance."""
+        self.core_config = core_config
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        """ Get the input spec for this pipeline stage."""
+        return DivPipeCoreInterstageData(self.core_config)
+
+    def ospec(self):
+        """ Get the output spec for this pipeline stage."""
+        return DivPipeCoreOutputData(self.core_config)
+
+    def setup(self, m, i):
+        """ Pipeline stage setup. """
+        m.submodules.div_pipe_core_final = self
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        """ Pipeline stage process. """
+        return self.o  # return processed data (ignore i)
+
+    def elaborate(self, platform):
+        """ Elaborate into ``Module``. """
+        m = Module()
+
+        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root)
+        m.d.comb += self.o.remainder.eq(self.i.compare_lhs
+                                        - self.i.compare_rhs)
+
+        return m
-- 
2.30.2