+#!/usr/bin/env python3
# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
from nmigen.hdl.ast import Const
from .algorithm import (div_rem, UnsignedDivRem, DivRem,
Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
- fixed_rsqrt, FixedRSqrt)
+ fixed_rsqrt, FixedRSqrt, Operation,
+ FixedUDivRemSqrtRSqrt)
import unittest
import math
with self.subTest(n=n, d=d, q=q, r=r):
udr = UnsignedDivRem(n, d, bit_width, log2_radix)
for _ in range(250 * bit_width):
- self.assertEqual(n, udr.quotient * udr.divisor
- + udr.remainder)
+ self.assertEqual(udr.dividend, n)
+ self.assertEqual(udr.divisor, d)
+ self.assertEqual(udr.quotient_times_divisor,
+ udr.quotient * udr.divisor)
+ self.assertGreaterEqual(udr.dividend,
+ udr.quotient_times_divisor)
if udr.calculate_stage():
break
else:
self.fail("infinite loop")
- self.assertEqual(n, udr.quotient * udr.divisor
- + udr.remainder)
+ self.assertEqual(udr.dividend, n)
+ self.assertEqual(udr.divisor, d)
+ self.assertEqual(udr.quotient_times_divisor,
+ udr.quotient * udr.divisor)
+ self.assertGreaterEqual(udr.dividend,
+ udr.quotient_times_divisor)
self.assertEqual(udr.quotient, q)
self.assertEqual(udr.remainder, r)
self.assertEqual(value.bit_width, 8)
self.assertEqual(value.signed, True)
- def helper_test_from_bits(self, bit_width, fract_width):
+ def helper_tst_from_bits(self, bit_width, fract_width):
signed = False
for bits in range(1 << bit_width):
with self.subTest(bit_width=bit_width,
def test_from_bits(self):
for bit_width in range(1, 5):
for fract_width in range(bit_width):
- self.helper_test_from_bits(bit_width, fract_width)
+ self.helper_tst_from_bits(bit_width, fract_width)
def test_repr(self):
self.assertEqual(repr(Fixed.from_bits(1, 2, 3, False)),
def test_radix_16(self):
self.helper(4)
+
+
+class TestFixedUDivRemSqrtRSqrt(unittest.TestCase):
+ @staticmethod
+ def show_fixed(bits, fract_width, bit_width):
+ fixed = Fixed.from_bits(bits, fract_width, bit_width, False)
+ return f"{str(fixed)}:{repr(fixed)}"
+
+ def check_invariants(self,
+ dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix,
+ obj):
+ self.assertEqual(obj.dividend, dividend)
+ self.assertEqual(obj.divisor_radicand, divisor_radicand)
+ self.assertEqual(obj.operation, operation)
+ self.assertEqual(obj.bit_width, bit_width)
+ self.assertEqual(obj.fract_width, fract_width)
+ self.assertEqual(obj.log2_radix, log2_radix)
+ self.assertEqual(obj.root_times_radicand,
+ obj.quotient_root * obj.divisor_radicand)
+ self.assertGreaterEqual(obj.compare_lhs, obj.compare_rhs)
+ self.assertEqual(obj.remainder, obj.compare_lhs - obj.compare_rhs)
+ if operation is Operation.UDivRem:
+ self.assertEqual(obj.compare_lhs, obj.dividend << fract_width)
+ self.assertEqual(obj.compare_rhs,
+ (obj.quotient_root * obj.divisor_radicand)
+ << fract_width)
+ elif operation is Operation.SqrtRem:
+ self.assertEqual(obj.compare_lhs,
+ obj.divisor_radicand << (fract_width * 2))
+ self.assertEqual(obj.compare_rhs,
+ (obj.quotient_root * obj.quotient_root)
+ << fract_width)
+ else:
+ assert operation is Operation.RSqrtRem
+ self.assertEqual(obj.compare_lhs,
+ 1 << (fract_width * 3))
+ self.assertEqual(obj.compare_rhs,
+ obj.quotient_root * obj.quotient_root
+ * obj.divisor_radicand)
+
+ def handle_case(self,
+ dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix):
+ dividend_str = self.show_fixed(dividend,
+ fract_width * 2,
+ bit_width + fract_width)
+ divisor_radicand_str = self.show_fixed(divisor_radicand,
+ fract_width,
+ bit_width)
+ with self.subTest(dividend=dividend_str,
+ divisor_radicand=divisor_radicand_str,
+ operation=operation.name,
+ bit_width=bit_width,
+ fract_width=fract_width,
+ log2_radix=log2_radix):
+ if operation is Operation.UDivRem:
+ if divisor_radicand == 0:
+ return
+ quotient_root, remainder = div_rem(dividend,
+ divisor_radicand,
+ bit_width * 3,
+ False)
+ remainder <<= fract_width
+ elif operation is Operation.SqrtRem:
+ root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand,
+ fract_width,
+ bit_width,
+ False))
+ self.assertEqual(root_remainder.root.bit_width,
+ bit_width)
+ self.assertEqual(root_remainder.root.fract_width,
+ fract_width)
+ self.assertEqual(root_remainder.remainder.bit_width,
+ bit_width * 2)
+ self.assertEqual(root_remainder.remainder.fract_width,
+ fract_width * 2)
+ quotient_root = root_remainder.root.bits
+ remainder = root_remainder.remainder.bits << fract_width
+ else:
+ assert operation is Operation.RSqrtRem
+ if divisor_radicand == 0:
+ return
+ root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand,
+ fract_width,
+ bit_width,
+ False))
+ self.assertEqual(root_remainder.root.bit_width,
+ bit_width)
+ self.assertEqual(root_remainder.root.fract_width,
+ fract_width)
+ self.assertEqual(root_remainder.remainder.bit_width,
+ bit_width * 3)
+ self.assertEqual(root_remainder.remainder.fract_width,
+ fract_width * 3)
+ quotient_root = root_remainder.root.bits
+ remainder = root_remainder.remainder.bits
+ if quotient_root >= (1 << bit_width):
+ return
+ quotient_root_str = self.show_fixed(quotient_root,
+ fract_width,
+ bit_width)
+ remainder_str = self.show_fixed(remainder,
+ fract_width * 3,
+ bit_width * 3)
+ with self.subTest(quotient_root=quotient_root_str,
+ remainder=remainder_str):
+ obj = FixedUDivRemSqrtRSqrt(dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix)
+ for _ in range(250 * bit_width):
+ self.check_invariants(dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix,
+ obj)
+ if obj.calculate_stage():
+ break
+ else:
+ self.fail("infinite loop")
+ self.check_invariants(dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix,
+ obj)
+ self.assertEqual(obj.quotient_root, quotient_root)
+ self.assertEqual(obj.remainder, remainder)
+
+ def helper(self, log2_radix, operation):
+ bit_width_range = range(1, 8)
+ if operation is Operation.UDivRem:
+ bit_width_range = range(1, 6)
+ for bit_width in bit_width_range:
+ for fract_width in range(bit_width):
+ for divisor_radicand in range(1 << bit_width):
+ dividend_range = range(1)
+ if operation is Operation.UDivRem:
+ dividend_range = range(1 << (bit_width + fract_width))
+ for dividend in dividend_range:
+ self.handle_case(dividend,
+ divisor_radicand,
+ operation,
+ bit_width,
+ fract_width,
+ log2_radix)
+
+ def test_radix_2_UDiv(self):
+ self.helper(1, Operation.UDivRem)
+
+ def test_radix_4_UDiv(self):
+ self.helper(2, Operation.UDivRem)
+
+ def test_radix_8_UDiv(self):
+ self.helper(3, Operation.UDivRem)
+
+ def test_radix_16_UDiv(self):
+ self.helper(4, Operation.UDivRem)
+
+ def test_radix_2_Sqrt(self):
+ self.helper(1, Operation.SqrtRem)
+
+ def test_radix_4_Sqrt(self):
+ self.helper(2, Operation.SqrtRem)
+
+ def test_radix_8_Sqrt(self):
+ self.helper(3, Operation.SqrtRem)
+
+ def test_radix_16_Sqrt(self):
+ self.helper(4, Operation.SqrtRem)
+
+ def test_radix_2_RSqrt(self):
+ self.helper(1, Operation.RSqrtRem)
+
+ def test_radix_4_RSqrt(self):
+ self.helper(2, Operation.RSqrtRem)
+
+ def test_radix_8_RSqrt(self):
+ self.helper(3, Operation.RSqrtRem)
+
+ def test_radix_16_RSqrt(self):
+ self.helper(4, Operation.RSqrtRem)
+
+ def test_int_div(self):
+ bit_width = 8
+ fract_width = 4
+ log2_radix = 3
+ for dividend in range(1 << bit_width):
+ for divisor in range(1, 1 << bit_width):
+ obj = FixedUDivRemSqrtRSqrt(dividend,
+ divisor,
+ Operation.UDivRem,
+ bit_width,
+ fract_width,
+ log2_radix)
+ obj.calculate()
+ quotient, remainder = div_rem(dividend,
+ divisor,
+ bit_width,
+ False)
+ shifted_remainder = remainder << fract_width
+ with self.subTest(dividend=dividend,
+ divisor=divisor,
+ quotient=quotient,
+ remainder=remainder,
+ shifted_remainder=shifted_remainder):
+ self.assertEqual(obj.quotient_root, quotient)
+ self.assertEqual(obj.remainder, shifted_remainder)
+
+ def test_fract_div(self):
+ bit_width = 8
+ fract_width = 4
+ log2_radix = 3
+ for dividend in range(1 << bit_width):
+ for divisor in range(1, 1 << bit_width):
+ obj = FixedUDivRemSqrtRSqrt(dividend << fract_width,
+ divisor,
+ Operation.UDivRem,
+ bit_width,
+ fract_width,
+ log2_radix)
+ obj.calculate()
+ quotient = (dividend << fract_width) // divisor
+ if quotient >= (1 << bit_width):
+ continue
+ remainder = (dividend << fract_width) % divisor
+ shifted_remainder = remainder << fract_width
+ with self.subTest(dividend=dividend,
+ divisor=divisor,
+ quotient=quotient,
+ remainder=remainder,
+ shifted_remainder=shifted_remainder):
+ self.assertEqual(obj.quotient_root, quotient)
+ self.assertEqual(obj.remainder, shifted_remainder)
+
+
+if __name__ == '__main__':
+ unittest.main()