X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fieee754%2Fdiv_rem_sqrt_rsqrt%2Falgorithm.py;h=84ea1d4c78965778529854f2f35bed9cb344b0a0;hb=314c4130602c897d99b7ca9b3af1332df1070bee;hp=7fa051f2fdcfed335e04580d713a61199cc63ead;hpb=397136bfba5b2888b050aac4448f3aa22b0d8709;p=ieee754fpu.git diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py index 7fa051f2..84ea1d4c 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py @@ -8,6 +8,7 @@ code for simulating/testing the various algorithms from nmigen.hdl.ast import Const import math +import enum def div_rem(dividend, divisor, bit_width, signed): @@ -37,12 +38,14 @@ class UnsignedDivRem: NOT the same as the // or % operators - :attribute remainder: the remainder and/or dividend + :attribute dividend: the dividend + :attribute remainder: the remainder :attribute divisor: the divisor :attribute bit_width: the bit width of the inputs/outputs :attribute log2_radix: the base-2 log of the division radix. The number of bits of quotient that are calculated per pipeline stage. :attribute quotient: the quotient + :attribute quotient_times_divisor: ``quotient * divisor`` :attribute current_shift: the current bit index """ @@ -55,11 +58,12 @@ class UnsignedDivRem: :param log2_radix: the base-2 log of the division radix. The number of bits of quotient that are calculated per pipeline stage. """ - self.remainder = Const.normalize(dividend, (bit_width, False)) + self.dividend = Const.normalize(dividend, (bit_width, False)) self.divisor = Const.normalize(divisor, (bit_width, False)) self.bit_width = bit_width self.log2_radix = log2_radix self.quotient = 0 + self.quotient_times_divisor = self.quotient * self.divisor self.current_shift = bit_width def calculate_stage(self): @@ -73,17 +77,23 @@ class UnsignedDivRem: assert log2_radix > 0 self.current_shift -= log2_radix radix = 1 << log2_radix - remainders = [] + trial_values = [] for i in range(radix): - v = (self.divisor * i) << self.current_shift - remainders.append(self.remainder - v) + v = self.quotient_times_divisor + v += (self.divisor * i) << self.current_shift + trial_values.append(v) quotient_bits = 0 + next_product = self.quotient_times_divisor for i in range(radix): - if remainders[i] >= 0: + if self.dividend >= trial_values[i]: quotient_bits = i - self.remainder = remainders[quotient_bits] + next_product = trial_values[i] + self.quotient_times_divisor = next_product self.quotient |= quotient_bits << self.current_shift - return self.current_shift == 0 + if self.current_shift == 0: + self.remainder = self.dividend - self.quotient_times_divisor + return True + return False def calculate(self): """ Calculate the results of the division. @@ -676,3 +686,156 @@ class FixedRSqrt: while not self.calculate_stage(): pass return self + + +class Operation(enum.Enum): + """ Operation for ``FixedUDivRemSqrtRSqrt``. """ + + UDivRem = "unsigned-divide/remainder" + SqrtRem = "square-root/remainder" + RSqrtRem = "reciprocal-square-root/remainder" + + +class FixedUDivRemSqrtRSqrt: + """ Combined class for computing fixed-point unsigned div/rem/sqrt/rsqrt. + + Algorithm based on ``UnsignedDivRem``, ``FixedSqrt``, and ``FixedRSqrt``. + + Formulas solved are: + * div/rem: + ``dividend == quotient_root * divisor_radicand`` + * sqrt/rem: + ``divisor_radicand == quotient_root * quotient_root`` + * rsqrt/rem: + ``1 == quotient_root * quotient_root * divisor_radicand`` + + The remainder is the left-hand-side of the comparison minus the + right-hand-side of the comparison in the above formulas. + + Important: not all variables have the same bit-width or fract-width. For + instance, ``dividend`` has a bit-width of ``bit_width + fract_width`` + and a fract-width of ``2 * fract_width`` bits. + + :attribute dividend: dividend for div/rem. Variable with a bit-width of + ``bit_width + fract_width`` and a fract-width of ``fract_width * 2`` + bits. + :attribute divisor_radicand: divisor for div/rem and radicand for + sqrt/rsqrt. Variable with a bit-width of ``bit_width`` and a + fract-width of ``fract_width`` bits. + :attribute operation: the ``Operation`` to be computed. + :attribute quotient_root: the quotient or root part of the result of the + operation. Variable with a bit-width of ``bit_width`` and a fract-width + of ``fract_width`` bits. + :attribute remainder: the remainder part of the result of the operation. + Variable with a bit-width of ``bit_width * 3`` and a fract-width + of ``fract_width * 3`` bits. + :attribute root_times_radicand: ``quotient_root * divisor_radicand``. + Variable with a bit-width of ``bit_width * 2`` and a fract-width of + ``fract_width * 2`` bits. + :attribute compare_lhs: The left-hand-side of the comparison in the + equation to be solved. Variable with a bit-width of ``bit_width * 3`` + and a fract-width of ``fract_width * 3`` bits. + :attribute compare_rhs: The right-hand-side of the comparison in the + equation to be solved. Variable with a bit-width of ``bit_width * 3`` + and a fract-width of ``fract_width * 3`` bits. + :attribute bit_width: base bit-width. Constant int. + :attribute fract_width: base fract-width. Specifies location of base-2 + radix point. Constant int. + :attribute log2_radix: number of bits of ``quotient_root`` that should be + computed per pipeline stage (invocation of ``calculate_stage``). + Constant int. + :attribute current_shift: the current bit index. Variable int. + """ + + def __init__(self, + dividend, + divisor_radicand, + operation, + bit_width, + fract_width, + log2_radix): + """ Create a new ``FixedUDivRemSqrtRSqrt``. + + :param dividend: ``dividend`` attribute's initializer. + :param divisor_radicand: ``divisor_radicand`` attribute's initializer. + :param operation: ``operation`` attribute's initializer. + :param bit_width: ``bit_width`` attribute's initializer. + :param fract_width: ``fract_width`` attribute's initializer. + :param log2_radix: ``log2_radix`` attribute's initializer. + """ + assert bit_width > 0 + assert fract_width >= 0 + assert fract_width <= bit_width + assert log2_radix > 0 + self.dividend = Const.normalize(dividend, + (bit_width + fract_width, False)) + self.divisor_radicand = Const.normalize(divisor_radicand, + (bit_width, False)) + self.quotient_root = 0 + self.root_times_radicand = 0 + if operation is Operation.UDivRem: + self.compare_lhs = self.dividend << fract_width + elif operation is Operation.SqrtRem: + self.compare_lhs = self.divisor_radicand << (fract_width * 2) + else: + assert operation is Operation.RSqrtRem + self.compare_lhs = 1 << (fract_width * 3) + self.compare_rhs = 0 + self.remainder = self.compare_lhs + self.operation = operation + self.bit_width = bit_width + self.fract_width = fract_width + self.log2_radix = log2_radix + self.current_shift = bit_width + + def calculate_stage(self): + """ Calculate the next pipeline stage of the operation. + + :returns bool: True if this is the last pipeline stage. + """ + if self.current_shift == 0: + return True + log2_radix = min(self.log2_radix, self.current_shift) + assert log2_radix > 0 + self.current_shift -= log2_radix + radix = 1 << log2_radix + trial_compare_rhs_values = [] + for trial_bits in range(radix): + shifted_trial_bits = trial_bits << self.current_shift + shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits + v = self.compare_rhs + if self.operation is Operation.UDivRem: + factor1 = self.divisor_radicand * shifted_trial_bits + v += factor1 << self.fract_width + elif self.operation is Operation.SqrtRem: + factor1 = self.quotient_root * (shifted_trial_bits << 1) + v += factor1 << self.fract_width + factor2 = shifted_trial_bits_sqrd + v += factor2 << self.fract_width + else: + assert self.operation is Operation.RSqrtRem + factor1 = self.root_times_radicand * (shifted_trial_bits << 1) + v += factor1 + factor2 = self.divisor_radicand * shifted_trial_bits_sqrd + v += factor2 + trial_compare_rhs_values.append(v) + shifted_next_bits = 0 + next_compare_rhs = trial_compare_rhs_values[0] + for trial_bits in range(radix): + if self.compare_lhs >= trial_compare_rhs_values[trial_bits]: + shifted_next_bits = trial_bits << self.current_shift + next_compare_rhs = trial_compare_rhs_values[trial_bits] + self.root_times_radicand += self.divisor_radicand * shifted_next_bits + self.compare_rhs = next_compare_rhs + self.quotient_root |= shifted_next_bits + self.remainder = self.compare_lhs - self.compare_rhs + return self.current_shift == 0 + + def calculate(self): + """ Calculate the results of the operation. + + :returns: self + """ + while not self.calculate_stage(): + pass + return self