From 4b48577681ff12d8e3e8b88886c5f7b23e5ec718 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 3 Jul 2019 01:46:42 -0700 Subject: [PATCH] implemented FixedUDivRemSqrtRSqrt --- src/ieee754/div_rem_sqrt_rsqrt/algorithm.py | 154 ++++++++++++++ .../div_rem_sqrt_rsqrt/test_algorithm.py | 199 +++++++++++++++++- 2 files changed, 352 insertions(+), 1 deletion(-) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py index 7fa051f2..6ba28311 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py @@ -8,6 +8,7 @@ code for simulating/testing the various algorithms from nmigen.hdl.ast import Const import math +import enum def div_rem(dividend, divisor, bit_width, signed): @@ -676,3 +677,156 @@ class FixedRSqrt: while not self.calculate_stage(): pass return self + + +class Operation(enum.Enum): + """ Operation for ``FixedUDivRemSqrtRSqrt``. """ + + UDivRem = "unsigned-divide/remainder" + SqrtRem = "square-root/remainder" + RSqrtRem = "reciprocal-square-root/remainder" + + +class FixedUDivRemSqrtRSqrt: + """ Combined class for computing fixed-point unsigned div/rem/sqrt/rsqrt. + + Algorithm based on ``UnsignedDivRem``, ``FixedSqrt``, and ``FixedRSqrt``. + + Formulas solved are: + * div/rem: + ``dividend == quotient_root * divisor_radicand`` + * sqrt/rem: + ``divisor_radicand == quotient_root * quotient_root`` + * rsqrt/rem: + ``1 == quotient_root * quotient_root * divisor_radicand`` + + The remainder is the left-hand-side of the comparison minus the + right-hand-side of the comparison in the above formulas. + + Important: not all variables have the same bit-width or fract-width. For + instance, ``dividend`` has a bit-width of ``bit_width + fract_width`` + and a fract-width of ``2 * fract_width`` bits. + + :attribute dividend: dividend for div/rem. Variable with a bit-width of + ``bit_width + fract_width`` and a fract-width of ``fract_width * 2`` + bits. + :attribute divisor_radicand: divisor for div/rem and radicand for + sqrt/rsqrt. Variable with a bit-width of ``bit_width`` and a + fract-width of ``fract_width`` bits. + :attribute operation: the ``Operation`` to be computed. + :attribute quotient_root: the quotient or root part of the result of the + operation. Variable with a bit-width of ``bit_width`` and a fract-width + of ``fract_width`` bits. + :attribute remainder: the remainder part of the result of the operation. + Variable with a bit-width of ``bit_width * 3`` and a fract-width + of ``fract_width * 3`` bits. + :attribute root_times_radicand: ``quotient_root * divisor_radicand``. + Variable with a bit-width of ``bit_width * 2`` and a fract-width of + ``fract_width * 2`` bits. + :attribute compare_lhs: The left-hand-side of the comparison in the + equation to be solved. Variable with a bit-width of ``bit_width * 3`` + and a fract-width of ``fract_width * 3`` bits. + :attribute compare_rhs: The right-hand-side of the comparison in the + equation to be solved. Variable with a bit-width of ``bit_width * 3`` + and a fract-width of ``fract_width * 3`` bits. + :attribute bit_width: base bit-width. Constant int. + :attribute fract_width: base fract-width. Specifies location of base-2 + radix point. Constant int. + :attribute log2_radix: number of bits of ``quotient_root`` that should be + computed per pipeline stage (invocation of ``calculate_stage``). + Constant int. + :attribute current_shift: the current bit index. Variable int. + """ + + def __init__(self, + dividend, + divisor_radicand, + operation, + bit_width, + fract_width, + log2_radix): + """ Create a new ``FixedUDivRemSqrtRSqrt``. + + :param dividend: ``dividend`` attribute's initializer. + :param divisor_radicand: ``divisor_radicand`` attribute's initializer. + :param operation: ``operation`` attribute's initializer. + :param bit_width: ``bit_width`` attribute's initializer. + :param fract_width: ``fract_width`` attribute's initializer. + :param log2_radix: ``log2_radix`` attribute's initializer. + """ + assert bit_width > 0 + assert fract_width >= 0 + assert fract_width <= bit_width + assert log2_radix > 0 + self.dividend = Const.normalize(dividend, + (bit_width + fract_width, False)) + self.divisor_radicand = Const.normalize(divisor_radicand, + (bit_width, False)) + self.quotient_root = 0 + self.root_times_radicand = 0 + if operation is Operation.UDivRem: + self.compare_lhs = self.dividend << fract_width + elif operation is Operation.SqrtRem: + self.compare_lhs = self.divisor_radicand << (fract_width * 2) + else: + assert operation is Operation.RSqrtRem + self.compare_lhs = 1 << (fract_width * 3) + self.compare_rhs = 0 + self.remainder = self.compare_lhs + self.operation = operation + self.bit_width = bit_width + self.fract_width = fract_width + self.log2_radix = log2_radix + self.current_shift = bit_width + + def calculate_stage(self): + """ Calculate the next pipeline stage of the operation. + + :returns bool: True if this is the last pipeline stage. + """ + if self.current_shift == 0: + return True + log2_radix = min(self.log2_radix, self.current_shift) + assert log2_radix > 0 + self.current_shift -= log2_radix + radix = 1 << log2_radix + trial_compare_rhs_values = [] + for trial_bits in range(radix): + shifted_trial_bits = trial_bits << self.current_shift + shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits + v = self.compare_rhs + if self.operation is Operation.UDivRem: + factor1 = self.divisor_radicand * shifted_trial_bits + v += factor1 << self.fract_width + elif self.operation is Operation.SqrtRem: + factor1 = self.quotient_root * (shifted_trial_bits << 1) + v += factor1 << self.fract_width + factor2 = shifted_trial_bits_sqrd + v += factor2 << self.fract_width + else: + assert self.operation is Operation.RSqrtRem + factor1 = self.root_times_radicand * (shifted_trial_bits << 1) + v += factor1 + factor2 = self.divisor_radicand * shifted_trial_bits_sqrd + v += factor2 + trial_compare_rhs_values.append(v) + shifted_next_bits = 0 + next_compare_rhs = trial_compare_rhs_values[0] + for trial_bits in range(radix): + if self.compare_lhs >= trial_compare_rhs_values[trial_bits]: + shifted_next_bits = trial_bits << self.current_shift + next_compare_rhs = trial_compare_rhs_values[trial_bits] + self.root_times_radicand += self.divisor_radicand * shifted_next_bits + self.compare_rhs = next_compare_rhs + self.quotient_root |= shifted_next_bits + self.remainder = self.compare_lhs - self.compare_rhs + return self.current_shift == 0 + + def calculate(self): + """ Calculate the results of the operation. + + :returns: self + """ + while not self.calculate_stage(): + pass + return self diff --git a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py index 081bda97..c5c3e7b3 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py @@ -4,7 +4,8 @@ from nmigen.hdl.ast import Const from .algorithm import (div_rem, UnsignedDivRem, DivRem, Fixed, RootRemainder, fixed_sqrt, FixedSqrt, - fixed_rsqrt, FixedRSqrt) + fixed_rsqrt, FixedRSqrt, Operation, + FixedUDivRemSqrtRSqrt) import unittest import math @@ -872,3 +873,199 @@ class TestFixedRSqrt(unittest.TestCase): def test_radix_16(self): self.helper(4) + + +class TestFixedUDivRemSqrtRSqrt(unittest.TestCase): + @staticmethod + def show_fixed(bits, fract_width, bit_width): + fixed = Fixed.from_bits(bits, fract_width, bit_width, False) + return f"{str(fixed)}:{repr(fixed)}" + + def check_invariants(self, + dividend, + divisor_radicand, + operation, + bit_width, + fract_width, + log2_radix, + obj): + self.assertEqual(obj.dividend, dividend) + self.assertEqual(obj.divisor_radicand, divisor_radicand) + self.assertEqual(obj.operation, operation) + self.assertEqual(obj.bit_width, bit_width) + self.assertEqual(obj.fract_width, fract_width) + self.assertEqual(obj.log2_radix, log2_radix) + self.assertEqual(obj.root_times_radicand, + obj.quotient_root * obj.divisor_radicand) + self.assertGreaterEqual(obj.compare_lhs, obj.compare_rhs) + self.assertEqual(obj.remainder, obj.compare_lhs - obj.compare_rhs) + if operation is Operation.UDivRem: + self.assertEqual(obj.compare_lhs, obj.dividend << fract_width) + self.assertEqual(obj.compare_rhs, + (obj.quotient_root * obj.divisor_radicand) + << fract_width) + elif operation is Operation.SqrtRem: + self.assertEqual(obj.compare_lhs, + obj.divisor_radicand << (fract_width * 2)) + self.assertEqual(obj.compare_rhs, + (obj.quotient_root * obj.quotient_root) + << fract_width) + else: + assert operation is Operation.RSqrtRem + self.assertEqual(obj.compare_lhs, + 1 << (fract_width * 3)) + self.assertEqual(obj.compare_rhs, + obj.quotient_root * obj.quotient_root + * obj.divisor_radicand) + + def handle_case(self, + dividend, + divisor_radicand, + operation, + bit_width, + fract_width, + log2_radix): + dividend_str = self.show_fixed(dividend, + fract_width * 2, + bit_width + fract_width) + divisor_radicand_str = self.show_fixed(divisor_radicand, + fract_width, + bit_width) + with self.subTest(dividend=dividend_str, + divisor_radicand=divisor_radicand_str, + operation=operation.name, + bit_width=bit_width, + fract_width=fract_width, + log2_radix=log2_radix): + if operation is Operation.UDivRem: + if divisor_radicand == 0: + return + quotient_root, remainder = div_rem(dividend, + divisor_radicand, + bit_width * 3, + False) + remainder <<= fract_width + elif operation is Operation.SqrtRem: + root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand, + fract_width, + bit_width, + False)) + self.assertEqual(root_remainder.root.bit_width, + bit_width) + self.assertEqual(root_remainder.root.fract_width, + fract_width) + self.assertEqual(root_remainder.remainder.bit_width, + bit_width * 2) + self.assertEqual(root_remainder.remainder.fract_width, + fract_width * 2) + quotient_root = root_remainder.root.bits + remainder = root_remainder.remainder.bits << fract_width + else: + assert operation is Operation.RSqrtRem + if divisor_radicand == 0: + return + root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand, + fract_width, + bit_width, + False)) + self.assertEqual(root_remainder.root.bit_width, + bit_width) + self.assertEqual(root_remainder.root.fract_width, + fract_width) + self.assertEqual(root_remainder.remainder.bit_width, + bit_width * 3) + self.assertEqual(root_remainder.remainder.fract_width, + fract_width * 3) + quotient_root = root_remainder.root.bits + remainder = root_remainder.remainder.bits + if quotient_root >= (1 << bit_width): + return + quotient_root_str = self.show_fixed(quotient_root, + fract_width, + bit_width) + remainder_str = self.show_fixed(remainder, + fract_width * 3, + bit_width * 3) + with self.subTest(quotient_root=quotient_root_str, + remainder=remainder_str): + obj = FixedUDivRemSqrtRSqrt(dividend, + divisor_radicand, + operation, + bit_width, + fract_width, + log2_radix) + for _ in range(250 * bit_width): + self.check_invariants(dividend, + divisor_radicand, + operation, + bit_width, + fract_width, + log2_radix, + obj) + if obj.calculate_stage(): + break + else: + self.fail("infinite loop") + self.check_invariants(dividend, + divisor_radicand, + operation, + bit_width, + fract_width, + log2_radix, + obj) + self.assertEqual(obj.quotient_root, quotient_root) + self.assertEqual(obj.remainder, remainder) + + def helper(self, log2_radix, operation): + bit_width_range = range(1, 8) + if operation is Operation.UDivRem: + bit_width_range = range(1, 6) + for bit_width in bit_width_range: + for fract_width in range(bit_width): + for divisor_radicand in range(1 << bit_width): + dividend_range = range(1) + if operation is Operation.UDivRem: + dividend_range = range(1 << (bit_width + fract_width)) + for dividend in dividend_range: + self.handle_case(dividend, + divisor_radicand, + operation, + bit_width, + fract_width, + log2_radix) + + def test_radix_2_UDiv(self): + self.helper(1, Operation.UDivRem) + + def test_radix_4_UDiv(self): + self.helper(2, Operation.UDivRem) + + def test_radix_8_UDiv(self): + self.helper(3, Operation.UDivRem) + + def test_radix_16_UDiv(self): + self.helper(4, Operation.UDivRem) + + def test_radix_2_Sqrt(self): + self.helper(1, Operation.SqrtRem) + + def test_radix_4_Sqrt(self): + self.helper(2, Operation.SqrtRem) + + def test_radix_8_Sqrt(self): + self.helper(3, Operation.SqrtRem) + + def test_radix_16_Sqrt(self): + self.helper(4, Operation.SqrtRem) + + def test_radix_2_RSqrt(self): + self.helper(1, Operation.RSqrtRem) + + def test_radix_4_RSqrt(self): + self.helper(2, Operation.RSqrtRem) + + def test_radix_8_RSqrt(self): + self.helper(3, Operation.RSqrtRem) + + def test_radix_16_RSqrt(self): + self.helper(4, Operation.RSqrtRem) -- 2.30.2