from nmigen.hdl.ast import Const
import math
+import enum
def div_rem(dividend, divisor, bit_width, signed):
while not self.calculate_stage():
pass
return self
+
+
+class Operation(enum.Enum):
+ """ Operation for ``FixedUDivRemSqrtRSqrt``. """
+
+ UDivRem = "unsigned-divide/remainder"
+ SqrtRem = "square-root/remainder"
+ RSqrtRem = "reciprocal-square-root/remainder"
+
+
+class FixedUDivRemSqrtRSqrt:
+ """ Combined class for computing fixed-point unsigned div/rem/sqrt/rsqrt.
+
+ Algorithm based on ``UnsignedDivRem``, ``FixedSqrt``, and ``FixedRSqrt``.
+
+ Formulas solved are:
+ * div/rem:
+ ``dividend == quotient_root * divisor_radicand``
+ * sqrt/rem:
+ ``divisor_radicand == quotient_root * quotient_root``
+ * rsqrt/rem:
+ ``1 == quotient_root * quotient_root * divisor_radicand``
+
+ The remainder is the left-hand-side of the comparison minus the
+ right-hand-side of the comparison in the above formulas.
+
+ Important: not all variables have the same bit-width or fract-width. For
+ instance, ``dividend`` has a bit-width of ``bit_width + fract_width``
+ and a fract-width of ``2 * fract_width`` bits.
+
+ :attribute dividend: dividend for div/rem. Variable with a bit-width of
+ ``bit_width + fract_width`` and a fract-width of ``fract_width * 2``
+ bits.
+ :attribute divisor_radicand: divisor for div/rem and radicand for
+ sqrt/rsqrt. Variable with a bit-width of ``bit_width`` and a
+ fract-width of ``fract_width`` bits.
+ :attribute operation: the ``Operation`` to be computed.
+ :attribute quotient_root: the quotient or root part of the result of the
+ operation. Variable with a bit-width of ``bit_width`` and a fract-width
+ of ``fract_width`` bits.
+ :attribute remainder: the remainder part of the result of the operation.
+ Variable with a bit-width of ``bit_width * 3`` and a fract-width
+ of ``fract_width * 3`` bits.
+ :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
+ Variable with a bit-width of ``bit_width * 2`` and a fract-width of
+ ``fract_width * 2`` bits.
+ :attribute compare_lhs: The left-hand-side of the comparison in the
+ equation to be solved. Variable with a bit-width of ``bit_width * 3``
+ and a fract-width of ``fract_width * 3`` bits.
+ :attribute compare_rhs: The right-hand-side of the comparison in the
+ equation to be solved. Variable with a bit-width of ``bit_width * 3``
+ and a fract-width of ``fract_width * 3`` bits.
+ :attribute bit_width: base bit-width. Constant int.
+ :attribute fract_width: base fract-width. Specifies location of base-2
+ radix point. Constant int.
+ :attribute log2_radix: number of bits of ``quotient_root`` that should be
+ computed per pipeline stage (invocation of ``calculate_stage``).
+ Constant int.
+ :attribute current_shift: the current bit index. Variable int.
+ """
+
+ def __init__(self,
+ dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix):
+ """ Create a new ``FixedUDivRemSqrtRSqrt``.
+
+ :param dividend: ``dividend`` attribute's initializer.
+ :param divisor_radicand: ``divisor_radicand`` attribute's initializer.
+ :param operation: ``operation`` attribute's initializer.
+ :param bit_width: ``bit_width`` attribute's initializer.
+ :param fract_width: ``fract_width`` attribute's initializer.
+ :param log2_radix: ``log2_radix`` attribute's initializer.
+ """
+ assert bit_width > 0
+ assert fract_width >= 0
+ assert fract_width <= bit_width
+ assert log2_radix > 0
+ self.dividend = Const.normalize(dividend,
+ (bit_width + fract_width, False))
+ self.divisor_radicand = Const.normalize(divisor_radicand,
+ (bit_width, False))
+ self.quotient_root = 0
+ self.root_times_radicand = 0
+ if operation is Operation.UDivRem:
+ self.compare_lhs = self.dividend << fract_width
+ elif operation is Operation.SqrtRem:
+ self.compare_lhs = self.divisor_radicand << (fract_width * 2)
+ else:
+ assert operation is Operation.RSqrtRem
+ self.compare_lhs = 1 << (fract_width * 3)
+ self.compare_rhs = 0
+ self.remainder = self.compare_lhs
+ self.operation = operation
+ self.bit_width = bit_width
+ self.fract_width = fract_width
+ self.log2_radix = log2_radix
+ self.current_shift = bit_width
+
+ def calculate_stage(self):
+ """ Calculate the next pipeline stage of the operation.
+
+ :returns bool: True if this is the last pipeline stage.
+ """
+ if self.current_shift == 0:
+ return True
+ log2_radix = min(self.log2_radix, self.current_shift)
+ assert log2_radix > 0
+ self.current_shift -= log2_radix
+ radix = 1 << log2_radix
+ trial_compare_rhs_values = []
+ for trial_bits in range(radix):
+ shifted_trial_bits = trial_bits << self.current_shift
+ shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits
+ v = self.compare_rhs
+ if self.operation is Operation.UDivRem:
+ factor1 = self.divisor_radicand * shifted_trial_bits
+ v += factor1 << self.fract_width
+ elif self.operation is Operation.SqrtRem:
+ factor1 = self.quotient_root * (shifted_trial_bits << 1)
+ v += factor1 << self.fract_width
+ factor2 = shifted_trial_bits_sqrd
+ v += factor2 << self.fract_width
+ else:
+ assert self.operation is Operation.RSqrtRem
+ factor1 = self.root_times_radicand * (shifted_trial_bits << 1)
+ v += factor1
+ factor2 = self.divisor_radicand * shifted_trial_bits_sqrd
+ v += factor2
+ trial_compare_rhs_values.append(v)
+ shifted_next_bits = 0
+ next_compare_rhs = trial_compare_rhs_values[0]
+ for trial_bits in range(radix):
+ if self.compare_lhs >= trial_compare_rhs_values[trial_bits]:
+ shifted_next_bits = trial_bits << self.current_shift
+ next_compare_rhs = trial_compare_rhs_values[trial_bits]
+ self.root_times_radicand += self.divisor_radicand * shifted_next_bits
+ self.compare_rhs = next_compare_rhs
+ self.quotient_root |= shifted_next_bits
+ self.remainder = self.compare_lhs - self.compare_rhs
+ return self.current_shift == 0
+
+ def calculate(self):
+ """ Calculate the results of the operation.
+
+ :returns: self
+ """
+ while not self.calculate_stage():
+ pass
+ return self
from nmigen.hdl.ast import Const
from .algorithm import (div_rem, UnsignedDivRem, DivRem,
Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
- fixed_rsqrt, FixedRSqrt)
+ fixed_rsqrt, FixedRSqrt, Operation,
+ FixedUDivRemSqrtRSqrt)
import unittest
import math
def test_radix_16(self):
self.helper(4)
+
+
+class TestFixedUDivRemSqrtRSqrt(unittest.TestCase):
+ @staticmethod
+ def show_fixed(bits, fract_width, bit_width):
+ fixed = Fixed.from_bits(bits, fract_width, bit_width, False)
+ return f"{str(fixed)}:{repr(fixed)}"
+
+ def check_invariants(self,
+ dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix,
+ obj):
+ self.assertEqual(obj.dividend, dividend)
+ self.assertEqual(obj.divisor_radicand, divisor_radicand)
+ self.assertEqual(obj.operation, operation)
+ self.assertEqual(obj.bit_width, bit_width)
+ self.assertEqual(obj.fract_width, fract_width)
+ self.assertEqual(obj.log2_radix, log2_radix)
+ self.assertEqual(obj.root_times_radicand,
+ obj.quotient_root * obj.divisor_radicand)
+ self.assertGreaterEqual(obj.compare_lhs, obj.compare_rhs)
+ self.assertEqual(obj.remainder, obj.compare_lhs - obj.compare_rhs)
+ if operation is Operation.UDivRem:
+ self.assertEqual(obj.compare_lhs, obj.dividend << fract_width)
+ self.assertEqual(obj.compare_rhs,
+ (obj.quotient_root * obj.divisor_radicand)
+ << fract_width)
+ elif operation is Operation.SqrtRem:
+ self.assertEqual(obj.compare_lhs,
+ obj.divisor_radicand << (fract_width * 2))
+ self.assertEqual(obj.compare_rhs,
+ (obj.quotient_root * obj.quotient_root)
+ << fract_width)
+ else:
+ assert operation is Operation.RSqrtRem
+ self.assertEqual(obj.compare_lhs,
+ 1 << (fract_width * 3))
+ self.assertEqual(obj.compare_rhs,
+ obj.quotient_root * obj.quotient_root
+ * obj.divisor_radicand)
+
+ def handle_case(self,
+ dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix):
+ dividend_str = self.show_fixed(dividend,
+ fract_width * 2,
+ bit_width + fract_width)
+ divisor_radicand_str = self.show_fixed(divisor_radicand,
+ fract_width,
+ bit_width)
+ with self.subTest(dividend=dividend_str,
+ divisor_radicand=divisor_radicand_str,
+ operation=operation.name,
+ bit_width=bit_width,
+ fract_width=fract_width,
+ log2_radix=log2_radix):
+ if operation is Operation.UDivRem:
+ if divisor_radicand == 0:
+ return
+ quotient_root, remainder = div_rem(dividend,
+ divisor_radicand,
+ bit_width * 3,
+ False)
+ remainder <<= fract_width
+ elif operation is Operation.SqrtRem:
+ root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand,
+ fract_width,
+ bit_width,
+ False))
+ self.assertEqual(root_remainder.root.bit_width,
+ bit_width)
+ self.assertEqual(root_remainder.root.fract_width,
+ fract_width)
+ self.assertEqual(root_remainder.remainder.bit_width,
+ bit_width * 2)
+ self.assertEqual(root_remainder.remainder.fract_width,
+ fract_width * 2)
+ quotient_root = root_remainder.root.bits
+ remainder = root_remainder.remainder.bits << fract_width
+ else:
+ assert operation is Operation.RSqrtRem
+ if divisor_radicand == 0:
+ return
+ root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand,
+ fract_width,
+ bit_width,
+ False))
+ self.assertEqual(root_remainder.root.bit_width,
+ bit_width)
+ self.assertEqual(root_remainder.root.fract_width,
+ fract_width)
+ self.assertEqual(root_remainder.remainder.bit_width,
+ bit_width * 3)
+ self.assertEqual(root_remainder.remainder.fract_width,
+ fract_width * 3)
+ quotient_root = root_remainder.root.bits
+ remainder = root_remainder.remainder.bits
+ if quotient_root >= (1 << bit_width):
+ return
+ quotient_root_str = self.show_fixed(quotient_root,
+ fract_width,
+ bit_width)
+ remainder_str = self.show_fixed(remainder,
+ fract_width * 3,
+ bit_width * 3)
+ with self.subTest(quotient_root=quotient_root_str,
+ remainder=remainder_str):
+ obj = FixedUDivRemSqrtRSqrt(dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix)
+ for _ in range(250 * bit_width):
+ self.check_invariants(dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix,
+ obj)
+ if obj.calculate_stage():
+ break
+ else:
+ self.fail("infinite loop")
+ self.check_invariants(dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix,
+ obj)
+ self.assertEqual(obj.quotient_root, quotient_root)
+ self.assertEqual(obj.remainder, remainder)
+
+ def helper(self, log2_radix, operation):
+ bit_width_range = range(1, 8)
+ if operation is Operation.UDivRem:
+ bit_width_range = range(1, 6)
+ for bit_width in bit_width_range:
+ for fract_width in range(bit_width):
+ for divisor_radicand in range(1 << bit_width):
+ dividend_range = range(1)
+ if operation is Operation.UDivRem:
+ dividend_range = range(1 << (bit_width + fract_width))
+ for dividend in dividend_range:
+ self.handle_case(dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix)
+
+ def test_radix_2_UDiv(self):
+ self.helper(1, Operation.UDivRem)
+
+ def test_radix_4_UDiv(self):
+ self.helper(2, Operation.UDivRem)
+
+ def test_radix_8_UDiv(self):
+ self.helper(3, Operation.UDivRem)
+
+ def test_radix_16_UDiv(self):
+ self.helper(4, Operation.UDivRem)
+
+ def test_radix_2_Sqrt(self):
+ self.helper(1, Operation.SqrtRem)
+
+ def test_radix_4_Sqrt(self):
+ self.helper(2, Operation.SqrtRem)
+
+ def test_radix_8_Sqrt(self):
+ self.helper(3, Operation.SqrtRem)
+
+ def test_radix_16_Sqrt(self):
+ self.helper(4, Operation.SqrtRem)
+
+ def test_radix_2_RSqrt(self):
+ self.helper(1, Operation.RSqrtRem)
+
+ def test_radix_4_RSqrt(self):
+ self.helper(2, Operation.RSqrtRem)
+
+ def test_radix_8_RSqrt(self):
+ self.helper(3, Operation.RSqrtRem)
+
+ def test_radix_16_RSqrt(self):
+ self.helper(4, Operation.RSqrtRem)