--- /dev/null
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+
+""" Algorithms for div/rem/sqrt/rsqrt.
+
+code for simulating/testing the various algorithms
+"""
+
+from nmigen.hdl.ast import Const
+
+
+def div_rem(dividend, divisor, bit_width, signed):
+ """ Compute the quotient/remainder following the RISC-V M extension.
+
+ NOT the same as the // or % operators
+ """
+ dividend = Const.normalize(dividend, (bit_width, signed))
+ divisor = Const.normalize(divisor, (bit_width, signed))
+ if divisor == 0:
+ quotient = -1
+ remainder = dividend
+ else:
+ quotient = abs(dividend) // abs(divisor)
+ remainder = abs(dividend) % abs(divisor)
+ if (dividend < 0) != (divisor < 0):
+ quotient = -quotient
+ if dividend < 0:
+ remainder = -remainder
+ quotient = Const.normalize(quotient, (bit_width, signed))
+ remainder = Const.normalize(remainder, (bit_width, signed))
+ return quotient, remainder
+
+
+class UnsignedDivRem:
+ """ Unsigned integer division/remainder following the RISC-V M extension.
+
+ NOT the same as the // or % operators
+
+ :attribute remainder: the remainder and/or dividend
+ :attribute divisor: the divisor
+ :attribute bit_width: the bit width of the inputs/outputs
+ :attribute log2_radix: the base-2 log of the division radix. The number of
+ bits of quotient that are calculated per pipeline stage.
+ :attribute quotient: the quotient
+ :attribute current_shift: the current bit index
+ """
+
+ def __init__(self, dividend, divisor, bit_width, log2_radix=3):
+ """ Create an UnsignedDivRem.
+
+ :param dividend: the dividend/numerator
+ :param divisor: the divisor/denominator
+ :param bit_width: the bit width of the inputs/outputs
+ :param log2_radix: the base-2 log of the division radix. The number of
+ bits of quotient that are calculated per pipeline stage.
+ """
+ self.remainder = Const.normalize(dividend, (bit_width, False))
+ self.divisor = Const.normalize(divisor, (bit_width, False))
+ self.bit_width = bit_width
+ self.log2_radix = log2_radix
+ self.quotient = 0
+ self.current_shift = bit_width
+
+ def calculate_stage(self):
+ """ Calculate the next pipeline stage of the division.
+
+ :returns bool: True if this is the last pipeline stage.
+ """
+ if self.current_shift == 0:
+ return True
+ log2_radix = min(self.log2_radix, self.current_shift)
+ assert log2_radix > 0
+ self.current_shift -= log2_radix
+ radix = 1 << log2_radix
+ remainders = []
+ for i in range(radix):
+ v = (self.divisor * i) << self.current_shift
+ remainders.append(self.remainder - v)
+ quotient_bits = 0
+ for i in range(radix):
+ if remainders[i] >= 0:
+ quotient_bits = i
+ self.remainder = remainders[quotient_bits]
+ self.quotient |= quotient_bits << self.current_shift
+ return self.current_shift == 0
+
+ def calculate(self):
+ """ Calculate the results of the division.
+
+ :returns: self
+ """
+ while not self.calculate_stage():
+ pass
+ return self
+
+
+class DivRem:
+ """ integer division/remainder following the RISC-V M extension.
+
+ NOT the same as the // or % operators
+
+ :attribute dividend: the dividend
+ :attribute divisor: the divisor
+ :attribute signed: if the inputs/outputs are signed instead of unsigned
+ :attribute quotient: the quotient
+ :attribute remainder: the remainder
+ :attribute divider: the base UnsignedDivRem
+ """
+
+ def __init__(self, dividend, divisor, bit_width, signed, log2_radix=3):
+ """ Create a DivRem.
+
+ :param dividend: the dividend/numerator
+ :param divisor: the divisor/denominator
+ :param bit_width: the bit width of the inputs/outputs
+ :param signed: if the inputs/outputs are signed instead of unsigned
+ :param log2_radix: the base-2 log of the division radix. The number of
+ bits of quotient that are calculated per pipeline stage.
+ """
+ self.dividend = Const.normalize(dividend, (bit_width, signed))
+ self.divisor = Const.normalize(divisor, (bit_width, signed))
+ self.signed = signed
+ self.quotient = 0
+ self.remainder = 0
+ self.divider = UnsignedDivRem(abs(dividend), abs(divisor),
+ bit_width, log2_radix)
+
+ def calculate_stage(self):
+ """ Calculate the next pipeline stage of the division.
+
+ :returns bool: True if this is the last pipeline stage.
+ """
+ if not self.divider.calculate_stage():
+ return False
+ divisor_sign = self.divisor < 0
+ dividend_sign = self.dividend < 0
+ if self.divisor != 0 and divisor_sign != dividend_sign:
+ quotient = -self.divider.quotient
+ else:
+ quotient = self.divider.quotient
+ if dividend_sign:
+ remainder = -self.divider.remainder
+ else:
+ remainder = self.divider.remainder
+ bit_width = self.divider.bit_width
+ self.quotient = Const.normalize(quotient, (bit_width, self.signed))
+ self.remainder = Const.normalize(remainder, (bit_width, self.signed))
+ return True
--- /dev/null
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+
+from nmigen.hdl.ast import Const
+from .algorithm import div_rem, UnsignedDivRem, DivRem
+import unittest
+
+
+class TestDivRemFn(unittest.TestCase):
+ def test_signed(self):
+ test_cases = [
+ # numerator, denominator, quotient, remainder
+ (-8, -8, 1, 0),
+ (-7, -8, 0, -7),
+ (-6, -8, 0, -6),
+ (-5, -8, 0, -5),
+ (-4, -8, 0, -4),
+ (-3, -8, 0, -3),
+ (-2, -8, 0, -2),
+ (-1, -8, 0, -1),
+ (0, -8, 0, 0),
+ (1, -8, 0, 1),
+ (2, -8, 0, 2),
+ (3, -8, 0, 3),
+ (4, -8, 0, 4),
+ (5, -8, 0, 5),
+ (6, -8, 0, 6),
+ (7, -8, 0, 7),
+ (-8, -7, 1, -1),
+ (-7, -7, 1, 0),
+ (-6, -7, 0, -6),
+ (-5, -7, 0, -5),
+ (-4, -7, 0, -4),
+ (-3, -7, 0, -3),
+ (-2, -7, 0, -2),
+ (-1, -7, 0, -1),
+ (0, -7, 0, 0),
+ (1, -7, 0, 1),
+ (2, -7, 0, 2),
+ (3, -7, 0, 3),
+ (4, -7, 0, 4),
+ (5, -7, 0, 5),
+ (6, -7, 0, 6),
+ (7, -7, -1, 0),
+ (-8, -6, 1, -2),
+ (-7, -6, 1, -1),
+ (-6, -6, 1, 0),
+ (-5, -6, 0, -5),
+ (-4, -6, 0, -4),
+ (-3, -6, 0, -3),
+ (-2, -6, 0, -2),
+ (-1, -6, 0, -1),
+ (0, -6, 0, 0),
+ (1, -6, 0, 1),
+ (2, -6, 0, 2),
+ (3, -6, 0, 3),
+ (4, -6, 0, 4),
+ (5, -6, 0, 5),
+ (6, -6, -1, 0),
+ (7, -6, -1, 1),
+ (-8, -5, 1, -3),
+ (-7, -5, 1, -2),
+ (-6, -5, 1, -1),
+ (-5, -5, 1, 0),
+ (-4, -5, 0, -4),
+ (-3, -5, 0, -3),
+ (-2, -5, 0, -2),
+ (-1, -5, 0, -1),
+ (0, -5, 0, 0),
+ (1, -5, 0, 1),
+ (2, -5, 0, 2),
+ (3, -5, 0, 3),
+ (4, -5, 0, 4),
+ (5, -5, -1, 0),
+ (6, -5, -1, 1),
+ (7, -5, -1, 2),
+ (-8, -4, 2, 0),
+ (-7, -4, 1, -3),
+ (-6, -4, 1, -2),
+ (-5, -4, 1, -1),
+ (-4, -4, 1, 0),
+ (-3, -4, 0, -3),
+ (-2, -4, 0, -2),
+ (-1, -4, 0, -1),
+ (0, -4, 0, 0),
+ (1, -4, 0, 1),
+ (2, -4, 0, 2),
+ (3, -4, 0, 3),
+ (4, -4, -1, 0),
+ (5, -4, -1, 1),
+ (6, -4, -1, 2),
+ (7, -4, -1, 3),
+ (-8, -3, 2, -2),
+ (-7, -3, 2, -1),
+ (-6, -3, 2, 0),
+ (-5, -3, 1, -2),
+ (-4, -3, 1, -1),
+ (-3, -3, 1, 0),
+ (-2, -3, 0, -2),
+ (-1, -3, 0, -1),
+ (0, -3, 0, 0),
+ (1, -3, 0, 1),
+ (2, -3, 0, 2),
+ (3, -3, -1, 0),
+ (4, -3, -1, 1),
+ (5, -3, -1, 2),
+ (6, -3, -2, 0),
+ (7, -3, -2, 1),
+ (-8, -2, 4, 0),
+ (-7, -2, 3, -1),
+ (-6, -2, 3, 0),
+ (-5, -2, 2, -1),
+ (-4, -2, 2, 0),
+ (-3, -2, 1, -1),
+ (-2, -2, 1, 0),
+ (-1, -2, 0, -1),
+ (0, -2, 0, 0),
+ (1, -2, 0, 1),
+ (2, -2, -1, 0),
+ (3, -2, -1, 1),
+ (4, -2, -2, 0),
+ (5, -2, -2, 1),
+ (6, -2, -3, 0),
+ (7, -2, -3, 1),
+ (-8, -1, -8, 0), # overflows and wraps around
+ (-7, -1, 7, 0),
+ (-6, -1, 6, 0),
+ (-5, -1, 5, 0),
+ (-4, -1, 4, 0),
+ (-3, -1, 3, 0),
+ (-2, -1, 2, 0),
+ (-1, -1, 1, 0),
+ (0, -1, 0, 0),
+ (1, -1, -1, 0),
+ (2, -1, -2, 0),
+ (3, -1, -3, 0),
+ (4, -1, -4, 0),
+ (5, -1, -5, 0),
+ (6, -1, -6, 0),
+ (7, -1, -7, 0),
+ (-8, 0, -1, -8),
+ (-7, 0, -1, -7),
+ (-6, 0, -1, -6),
+ (-5, 0, -1, -5),
+ (-4, 0, -1, -4),
+ (-3, 0, -1, -3),
+ (-2, 0, -1, -2),
+ (-1, 0, -1, -1),
+ (0, 0, -1, 0),
+ (1, 0, -1, 1),
+ (2, 0, -1, 2),
+ (3, 0, -1, 3),
+ (4, 0, -1, 4),
+ (5, 0, -1, 5),
+ (6, 0, -1, 6),
+ (7, 0, -1, 7),
+ (-8, 1, -8, 0),
+ (-7, 1, -7, 0),
+ (-6, 1, -6, 0),
+ (-5, 1, -5, 0),
+ (-4, 1, -4, 0),
+ (-3, 1, -3, 0),
+ (-2, 1, -2, 0),
+ (-1, 1, -1, 0),
+ (0, 1, 0, 0),
+ (1, 1, 1, 0),
+ (2, 1, 2, 0),
+ (3, 1, 3, 0),
+ (4, 1, 4, 0),
+ (5, 1, 5, 0),
+ (6, 1, 6, 0),
+ (7, 1, 7, 0),
+ (-8, 2, -4, 0),
+ (-7, 2, -3, -1),
+ (-6, 2, -3, 0),
+ (-5, 2, -2, -1),
+ (-4, 2, -2, 0),
+ (-3, 2, -1, -1),
+ (-2, 2, -1, 0),
+ (-1, 2, 0, -1),
+ (0, 2, 0, 0),
+ (1, 2, 0, 1),
+ (2, 2, 1, 0),
+ (3, 2, 1, 1),
+ (4, 2, 2, 0),
+ (5, 2, 2, 1),
+ (6, 2, 3, 0),
+ (7, 2, 3, 1),
+ (-8, 3, -2, -2),
+ (-7, 3, -2, -1),
+ (-6, 3, -2, 0),
+ (-5, 3, -1, -2),
+ (-4, 3, -1, -1),
+ (-3, 3, -1, 0),
+ (-2, 3, 0, -2),
+ (-1, 3, 0, -1),
+ (0, 3, 0, 0),
+ (1, 3, 0, 1),
+ (2, 3, 0, 2),
+ (3, 3, 1, 0),
+ (4, 3, 1, 1),
+ (5, 3, 1, 2),
+ (6, 3, 2, 0),
+ (7, 3, 2, 1),
+ (-8, 4, -2, 0),
+ (-7, 4, -1, -3),
+ (-6, 4, -1, -2),
+ (-5, 4, -1, -1),
+ (-4, 4, -1, 0),
+ (-3, 4, 0, -3),
+ (-2, 4, 0, -2),
+ (-1, 4, 0, -1),
+ (0, 4, 0, 0),
+ (1, 4, 0, 1),
+ (2, 4, 0, 2),
+ (3, 4, 0, 3),
+ (4, 4, 1, 0),
+ (5, 4, 1, 1),
+ (6, 4, 1, 2),
+ (7, 4, 1, 3),
+ (-8, 5, -1, -3),
+ (-7, 5, -1, -2),
+ (-6, 5, -1, -1),
+ (-5, 5, -1, 0),
+ (-4, 5, 0, -4),
+ (-3, 5, 0, -3),
+ (-2, 5, 0, -2),
+ (-1, 5, 0, -1),
+ (0, 5, 0, 0),
+ (1, 5, 0, 1),
+ (2, 5, 0, 2),
+ (3, 5, 0, 3),
+ (4, 5, 0, 4),
+ (5, 5, 1, 0),
+ (6, 5, 1, 1),
+ (7, 5, 1, 2),
+ (-8, 6, -1, -2),
+ (-7, 6, -1, -1),
+ (-6, 6, -1, 0),
+ (-5, 6, 0, -5),
+ (-4, 6, 0, -4),
+ (-3, 6, 0, -3),
+ (-2, 6, 0, -2),
+ (-1, 6, 0, -1),
+ (0, 6, 0, 0),
+ (1, 6, 0, 1),
+ (2, 6, 0, 2),
+ (3, 6, 0, 3),
+ (4, 6, 0, 4),
+ (5, 6, 0, 5),
+ (6, 6, 1, 0),
+ (7, 6, 1, 1),
+ (-8, 7, -1, -1),
+ (-7, 7, -1, 0),
+ (-6, 7, 0, -6),
+ (-5, 7, 0, -5),
+ (-4, 7, 0, -4),
+ (-3, 7, 0, -3),
+ (-2, 7, 0, -2),
+ (-1, 7, 0, -1),
+ (0, 7, 0, 0),
+ (1, 7, 0, 1),
+ (2, 7, 0, 2),
+ (3, 7, 0, 3),
+ (4, 7, 0, 4),
+ (5, 7, 0, 5),
+ (6, 7, 0, 6),
+ (7, 7, 1, 0),
+ ]
+ for (n, d, q, r) in test_cases:
+ self.assertEqual(div_rem(n, d, 4, True), (q, r))
+
+ def test_unsigned(self):
+ for n in range(16):
+ for d in range(16):
+ if d == 0:
+ q = 16 - 1
+ r = n
+ else:
+ # div_rem matches // and % for unsigned integers
+ q = n // d
+ r = n % d
+ self.assertEqual(div_rem(n, d, 4, False), (q, r))
+
+
+class TestUnsignedDivRem(unittest.TestCase):
+ def helper(self, log2_radix):
+ bit_width = 4
+ for n in range(1 << bit_width):
+ for d in range(1 << bit_width):
+ q, r = div_rem(n, d, bit_width, False)
+ with self.subTest(n=n, d=d, q=q, r=r):
+ udr = UnsignedDivRem(n, d, bit_width, log2_radix)
+ for _ in range(250 * bit_width):
+ self.assertEqual(n, udr.quotient * udr.divisor
+ + udr.remainder)
+ if udr.calculate_stage():
+ break
+ else:
+ self.fail("infinite loop")
+ self.assertEqual(n, udr.quotient * udr.divisor
+ + udr.remainder)
+ self.assertEqual(udr.quotient, q)
+ self.assertEqual(udr.remainder, r)
+
+ def test_radix_2(self):
+ self.helper(1)
+
+ def test_radix_4(self):
+ self.helper(2)
+
+ def test_radix_8(self):
+ self.helper(3)
+
+ def test_radix_16(self):
+ self.helper(4)
+
+
+class TestDivRem(unittest.TestCase):
+ def helper(self, log2_radix):
+ bit_width = 4
+ for n in range(1 << bit_width):
+ for d in range(1 << bit_width):
+ for signed in False, True:
+ n = Const.normalize(n, (bit_width, signed))
+ d = Const.normalize(d, (bit_width, signed))
+ q, r = div_rem(n, d, bit_width, signed)
+ with self.subTest(n=n, d=d, q=q, r=r, signed=signed):
+ dr = DivRem(n, d, bit_width, signed, log2_radix)
+ for _ in range(250 * bit_width):
+ if dr.calculate_stage():
+ break
+ else:
+ self.fail("infinite loop")
+ self.assertEqual(dr.quotient, q)
+ self.assertEqual(dr.remainder, r)
+
+ def test_radix_2(self):
+ self.helper(1)
+
+ def test_radix_4(self):
+ self.helper(2)
+
+ def test_radix_8(self):
+ self.helper(3)
+
+ def test_radix_16(self):
+ self.helper(4)