1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
4 from nmigen
.hdl
.ast
import Const
5 from .algorithm
import (div_rem
, UnsignedDivRem
, DivRem
,
6 Fixed
, RootRemainder
, fixed_sqrt
, FixedSqrt
,
7 fixed_rsqrt
, FixedRSqrt
, Operation
,
13 class TestDivRemFn(unittest
.TestCase
):
14 def test_signed(self
):
16 # numerator, denominator, quotient, remainder
129 (-8, -1, -8, 0), # overflows and wraps around
274 for (n
, d
, q
, r
) in test_cases
:
275 self
.assertEqual(div_rem(n
, d
, 4, True), (q
, r
))
277 def test_unsigned(self
):
284 # div_rem matches // and % for unsigned integers
287 self
.assertEqual(div_rem(n
, d
, 4, False), (q
, r
))
290 class TestUnsignedDivRem(unittest
.TestCase
):
291 def helper(self
, log2_radix
):
293 for n
in range(1 << bit_width
):
294 for d
in range(1 << bit_width
):
295 q
, r
= div_rem(n
, d
, bit_width
, False)
296 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
):
297 udr
= UnsignedDivRem(n
, d
, bit_width
, log2_radix
)
298 for _
in range(250 * bit_width
):
299 self
.assertEqual(udr
.dividend
, n
)
300 self
.assertEqual(udr
.divisor
, d
)
301 self
.assertEqual(udr
.quotient_times_divisor
,
302 udr
.quotient
* udr
.divisor
)
303 self
.assertGreaterEqual(udr
.dividend
,
304 udr
.quotient_times_divisor
)
305 if udr
.calculate_stage():
308 self
.fail("infinite loop")
309 self
.assertEqual(udr
.dividend
, n
)
310 self
.assertEqual(udr
.divisor
, d
)
311 self
.assertEqual(udr
.quotient_times_divisor
,
312 udr
.quotient
* udr
.divisor
)
313 self
.assertGreaterEqual(udr
.dividend
,
314 udr
.quotient_times_divisor
)
315 self
.assertEqual(udr
.quotient
, q
)
316 self
.assertEqual(udr
.remainder
, r
)
318 def test_radix_2(self
):
321 def test_radix_4(self
):
324 def test_radix_8(self
):
327 def test_radix_16(self
):
331 class TestDivRem(unittest
.TestCase
):
332 def helper(self
, log2_radix
):
334 for n
in range(1 << bit_width
):
335 for d
in range(1 << bit_width
):
336 for signed
in False, True:
337 n
= Const
.normalize(n
, (bit_width
, signed
))
338 d
= Const
.normalize(d
, (bit_width
, signed
))
339 q
, r
= div_rem(n
, d
, bit_width
, signed
)
340 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
, signed
=signed
):
341 dr
= DivRem(n
, d
, bit_width
, signed
, log2_radix
)
342 for _
in range(250 * bit_width
):
343 if dr
.calculate_stage():
346 self
.fail("infinite loop")
347 self
.assertEqual(dr
.quotient
, q
)
348 self
.assertEqual(dr
.remainder
, r
)
350 def test_radix_2(self
):
353 def test_radix_4(self
):
356 def test_radix_8(self
):
359 def test_radix_16(self
):
363 class TestFixed(unittest
.TestCase
):
364 def test_constructor(self
):
365 value
= Fixed(0, 0, 1, False)
366 self
.assertEqual(value
.bits
, 0)
367 self
.assertEqual(value
.fract_width
, 0)
368 self
.assertEqual(value
.bit_width
, 1)
369 self
.assertEqual(value
.signed
, False)
370 value
= Fixed(1, 2, 3, True)
371 self
.assertEqual(value
.bits
, -4)
372 self
.assertEqual(value
.fract_width
, 2)
373 self
.assertEqual(value
.bit_width
, 3)
374 self
.assertEqual(value
.signed
, True)
375 value
= Fixed(1, 2, 4, True)
376 self
.assertEqual(value
.bits
, 4)
377 self
.assertEqual(value
.fract_width
, 2)
378 self
.assertEqual(value
.bit_width
, 4)
379 self
.assertEqual(value
.signed
, True)
380 value
= Fixed(1.25, 4, 8, True)
381 self
.assertEqual(value
.bits
, 0x14)
382 self
.assertEqual(value
.fract_width
, 4)
383 self
.assertEqual(value
.bit_width
, 8)
384 self
.assertEqual(value
.signed
, True)
385 value
= Fixed(Fixed(2, 0, 12, False), 4, 8, True)
386 self
.assertEqual(value
.bits
, 0x20)
387 self
.assertEqual(value
.fract_width
, 4)
388 self
.assertEqual(value
.bit_width
, 8)
389 self
.assertEqual(value
.signed
, True)
390 value
= Fixed(0x2FF / 2 ** 8, 8, 12, False)
391 self
.assertEqual(value
.bits
, 0x2FF)
392 self
.assertEqual(value
.fract_width
, 8)
393 self
.assertEqual(value
.bit_width
, 12)
394 self
.assertEqual(value
.signed
, False)
395 value
= Fixed(value
, 4, 8, True)
396 self
.assertEqual(value
.bits
, 0x2F)
397 self
.assertEqual(value
.fract_width
, 4)
398 self
.assertEqual(value
.bit_width
, 8)
399 self
.assertEqual(value
.signed
, True)
401 def helper_tst_from_bits(self
, bit_width
, fract_width
):
403 for bits
in range(1 << bit_width
):
404 with self
.subTest(bit_width
=bit_width
,
405 fract_width
=fract_width
,
408 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
409 self
.assertEqual(value
.bit_width
, bit_width
)
410 self
.assertEqual(value
.fract_width
, fract_width
)
411 self
.assertEqual(value
.signed
, signed
)
412 self
.assertEqual(value
.bits
, bits
)
414 for bits
in range(-1 << (bit_width
- 1), 1 << (bit_width
- 1)):
415 with self
.subTest(bit_width
=bit_width
,
416 fract_width
=fract_width
,
419 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
420 self
.assertEqual(value
.bit_width
, bit_width
)
421 self
.assertEqual(value
.fract_width
, fract_width
)
422 self
.assertEqual(value
.signed
, signed
)
423 self
.assertEqual(value
.bits
, bits
)
425 def test_from_bits(self
):
426 for bit_width
in range(1, 5):
427 for fract_width
in range(bit_width
):
428 self
.helper_tst_from_bits(bit_width
, fract_width
)
431 self
.assertEqual(repr(Fixed
.from_bits(1, 2, 3, False)),
432 "Fixed.from_bits(1, 2, 3, False)")
433 self
.assertEqual(repr(Fixed
.from_bits(-4, 2, 3, True)),
434 "Fixed.from_bits(-4, 2, 3, True)")
435 self
.assertEqual(repr(Fixed
.from_bits(-4, 7, 10, True)),
436 "Fixed.from_bits(-4, 7, 10, True)")
438 def test_trunc(self
):
439 for i
in range(-8, 8):
440 value
= Fixed
.from_bits(i
, 2, 4, True)
441 with self
.subTest(value
=repr(value
)):
442 self
.assertEqual(math
.trunc(value
), math
.trunc(i
/ 4))
445 for i
in range(-8, 8):
446 value
= Fixed
.from_bits(i
, 2, 4, True)
447 with self
.subTest(value
=repr(value
)):
448 self
.assertEqual(int(value
), math
.trunc(value
))
450 def test_float(self
):
451 for i
in range(-8, 8):
452 value
= Fixed
.from_bits(i
, 2, 4, True)
453 with self
.subTest(value
=repr(value
)):
454 self
.assertEqual(float(value
), i
/ 4)
456 def test_floor(self
):
457 for i
in range(-8, 8):
458 value
= Fixed
.from_bits(i
, 2, 4, True)
459 with self
.subTest(value
=repr(value
)):
460 self
.assertEqual(math
.floor(value
), math
.floor(i
/ 4))
463 for i
in range(-8, 8):
464 value
= Fixed
.from_bits(i
, 2, 4, True)
465 with self
.subTest(value
=repr(value
)):
466 self
.assertEqual(math
.ceil(value
), math
.ceil(i
/ 4))
469 for i
in range(-8, 8):
470 value
= Fixed
.from_bits(i
, 2, 4, True)
471 expected
= -i
/ 4 if i
!= -8 else -2.0 # handle wrap-around
472 with self
.subTest(value
=repr(value
)):
473 self
.assertEqual(float(-value
), expected
)
476 for i
in range(-8, 8):
477 value
= Fixed
.from_bits(i
, 2, 4, True)
478 with self
.subTest(value
=repr(value
)):
480 self
.assertEqual(value
.bits
, i
)
483 for i
in range(-8, 8):
484 value
= Fixed
.from_bits(i
, 2, 4, True)
485 expected
= abs(i
) / 4 if i
!= -8 else -2.0 # handle wrap-around
486 with self
.subTest(value
=repr(value
)):
487 self
.assertEqual(float(abs(value
)), expected
)
490 for i
in range(-8, 8):
491 value
= Fixed
.from_bits(i
, 2, 4, True)
492 with self
.subTest(value
=repr(value
)):
493 self
.assertEqual(float(~value
), (~i
) / 4)
496 def get_test_values(max_bit_width
, include_int
):
497 for signed
in False, True:
499 for bits
in range(1 << max_bit_width
):
500 int_value
= Const
.normalize(bits
, (max_bit_width
, signed
))
502 for bit_width
in range(1, max_bit_width
):
503 for fract_width
in range(bit_width
+ 1):
504 for bits
in range(1 << bit_width
):
505 yield Fixed
.from_bits(bits
,
510 def binary_op_test_helper(self
,
513 width_combine_op
=max,
514 adjust_bits_op
=None):
515 def default_adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
516 return bits
<< (out_fract_width
- in_fract_width
)
517 if adjust_bits_op
is None:
518 adjust_bits_op
= default_adjust_bits_op
520 for lhs
in self
.get_test_values(max_bit_width
, True):
521 lhs_is_int
= isinstance(lhs
, int)
522 for rhs
in self
.get_test_values(max_bit_width
, not lhs_is_int
):
523 rhs_is_int
= isinstance(rhs
, int)
525 assert not rhs_is_int
526 lhs_int
= adjust_bits_op(lhs
, rhs
.fract_width
, 0)
527 int_result
= operation(lhs_int
, rhs
.bits
)
529 expected
= Fixed
.from_bits(int_result
,
534 expected
= int_result
536 rhs_int
= adjust_bits_op(rhs
, lhs
.fract_width
, 0)
537 int_result
= operation(lhs
.bits
, rhs_int
)
539 expected
= Fixed
.from_bits(int_result
,
544 expected
= int_result
545 elif lhs
.signed
!= rhs
.signed
:
548 fract_width
= width_combine_op(lhs
.fract_width
,
550 int_width
= width_combine_op(lhs
.bit_width
554 bit_width
= fract_width
+ int_width
555 lhs_int
= adjust_bits_op(lhs
.bits
,
558 rhs_int
= adjust_bits_op(rhs
.bits
,
561 int_result
= operation(lhs_int
, rhs_int
)
563 expected
= Fixed
.from_bits(int_result
,
568 expected
= int_result
569 with self
.subTest(lhs
=repr(lhs
),
571 expected
=repr(expected
)):
572 result
= operation(lhs
, rhs
)
574 self
.assertEqual(result
.bit_width
, expected
.bit_width
)
575 self
.assertEqual(result
.signed
, expected
.signed
)
576 self
.assertEqual(result
.fract_width
,
577 expected
.fract_width
)
578 self
.assertEqual(result
.bits
, expected
.bits
)
580 self
.assertEqual(result
, expected
)
583 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
+ rhs
)
586 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
- rhs
)
589 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
& rhs
)
592 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs | rhs
)
595 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs ^ rhs
)
598 def adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
600 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
* rhs
,
602 lambda l_width
, r_width
: l_width
+ r_width
,
612 self
.binary_op_test_helper(cmp, False)
615 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
< rhs
, False)
618 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
<= rhs
, False)
621 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
== rhs
, False)
624 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
!= rhs
, False)
627 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
> rhs
, False)
630 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
>= rhs
, False)
633 for v
in self
.get_test_values(6, False):
634 with self
.subTest(v
=repr(v
)):
635 self
.assertEqual(bool(v
), bool(v
.bits
))
638 self
.assertEqual(str(Fixed
.from_bits(0x1234, 0, 16, False)),
640 self
.assertEqual(str(Fixed
.from_bits(-0x1234, 0, 16, True)),
642 self
.assertEqual(str(Fixed
.from_bits(0x12345, 3, 20, True)),
644 self
.assertEqual(str(Fixed(123.625, 3, 12, True)),
647 self
.assertEqual(str(Fixed
.from_bits(0x1, 0, 20, True)),
649 self
.assertEqual(str(Fixed
.from_bits(0x2, 1, 20, True)),
651 self
.assertEqual(str(Fixed
.from_bits(0x4, 2, 20, True)),
653 self
.assertEqual(str(Fixed
.from_bits(0x9, 3, 20, True)),
655 self
.assertEqual(str(Fixed
.from_bits(0x12, 4, 20, True)),
657 self
.assertEqual(str(Fixed
.from_bits(0x24, 5, 20, True)),
659 self
.assertEqual(str(Fixed
.from_bits(0x48, 6, 20, True)),
661 self
.assertEqual(str(Fixed
.from_bits(0x91, 7, 20, True)),
663 self
.assertEqual(str(Fixed
.from_bits(0x123, 8, 20, True)),
665 self
.assertEqual(str(Fixed
.from_bits(0x246, 9, 20, True)),
667 self
.assertEqual(str(Fixed
.from_bits(0x48d, 10, 20, True)),
669 self
.assertEqual(str(Fixed
.from_bits(0x91a, 11, 20, True)),
671 self
.assertEqual(str(Fixed
.from_bits(0x1234, 12, 20, True)),
673 self
.assertEqual(str(Fixed
.from_bits(0x2468, 13, 20, True)),
675 self
.assertEqual(str(Fixed
.from_bits(0x48d1, 14, 20, True)),
677 self
.assertEqual(str(Fixed
.from_bits(0x91a2, 15, 20, True)),
679 self
.assertEqual(str(Fixed
.from_bits(0x12345, 16, 20, True)),
681 self
.assertEqual(str(Fixed
.from_bits(0x2468a, 17, 20, True)),
683 self
.assertEqual(str(Fixed
.from_bits(0x48d14, 18, 20, True)),
685 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, True)),
687 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, False)),
691 class TestFixedSqrtFn(unittest
.TestCase
):
692 def test_on_ints(self
):
693 for radicand
in range(-1, 32):
697 root
= math
.floor(math
.sqrt(radicand
))
698 remainder
= radicand
- root
* root
699 expected
= RootRemainder(root
, remainder
)
700 with self
.subTest(radicand
=radicand
, expected
=expected
):
701 self
.assertEqual(repr(fixed_sqrt(radicand
)), repr(expected
))
704 remainder
= radicand
- root
* root
705 expected
= RootRemainder(root
, remainder
)
706 with self
.subTest(radicand
=radicand
, expected
=expected
):
707 self
.assertEqual(repr(fixed_sqrt(radicand
)), repr(expected
))
709 def test_on_fixed(self
):
710 for signed
in False, True:
711 for bit_width
in range(1, 10):
712 for fract_width
in range(bit_width
):
713 for bits
in range(1 << bit_width
):
714 radicand
= Fixed
.from_bits(bits
,
720 root
= radicand
.with_value(math
.sqrt(float(radicand
)))
721 remainder
= radicand
- root
* root
722 expected
= RootRemainder(root
, remainder
)
723 with self
.subTest(radicand
=repr(radicand
),
724 expected
=repr(expected
)):
725 self
.assertEqual(repr(fixed_sqrt(radicand
)),
728 def test_misc_cases(self
):
731 (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
732 (Fixed(2, 30, 32, False),
733 "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)")
735 for radicand
, expected
in test_cases
:
736 with self
.subTest(radicand
=str(radicand
), expected
=expected
):
737 self
.assertEqual(str(fixed_sqrt(radicand
)), expected
)
740 class TestFixedSqrt(unittest
.TestCase
):
741 def helper(self
, log2_radix
):
742 for bit_width
in range(1, 8):
743 for fract_width
in range(bit_width
):
744 for radicand_bits
in range(1 << bit_width
):
745 radicand
= Fixed
.from_bits(radicand_bits
,
749 root_remainder
= fixed_sqrt(radicand
)
750 with self
.subTest(radicand
=repr(radicand
),
751 root_remainder
=repr(root_remainder
),
752 log2_radix
=log2_radix
):
753 obj
= FixedSqrt(radicand
, log2_radix
)
754 for _
in range(250 * bit_width
):
755 self
.assertEqual(obj
.root
* obj
.root
,
757 self
.assertGreaterEqual(obj
.radicand
,
759 if obj
.calculate_stage():
762 self
.fail("infinite loop")
763 self
.assertEqual(obj
.root
* obj
.root
,
765 self
.assertGreaterEqual(obj
.radicand
,
767 self
.assertEqual(obj
.remainder
,
768 obj
.radicand
- obj
.root_squared
)
769 self
.assertEqual(obj
.root
, root_remainder
.root
)
770 self
.assertEqual(obj
.remainder
,
771 root_remainder
.remainder
)
773 def test_radix_2(self
):
776 def test_radix_4(self
):
779 def test_radix_8(self
):
782 def test_radix_16(self
):
786 class TestFixedRSqrtFn(unittest
.TestCase
):
788 for bits
in range(1, 1 << 5):
789 radicand
= Fixed
.from_bits(bits
, 5, 12, False)
790 float_root
= 1 / math
.sqrt(float(radicand
))
791 root
= radicand
.with_value(float_root
)
792 remainder
= 1 - root
* root
* radicand
793 expected
= RootRemainder(root
, remainder
)
794 with self
.subTest(radicand
=repr(radicand
),
795 expected
=repr(expected
)):
796 self
.assertEqual(repr(fixed_rsqrt(radicand
)),
800 for signed
in False, True:
801 for bit_width
in range(1, 10):
802 for fract_width
in range(bit_width
):
803 for bits
in range(1 << bit_width
):
804 radicand
= Fixed
.from_bits(bits
,
810 float_root
= 1 / math
.sqrt(float(radicand
))
811 max_value
= radicand
.with_bits(
812 (1 << (bit_width
- signed
)) - 1)
813 if float_root
> float(max_value
):
816 root
= radicand
.with_value(float_root
)
817 remainder
= 1 - root
* root
* radicand
818 expected
= RootRemainder(root
, remainder
)
819 with self
.subTest(radicand
=repr(radicand
),
820 expected
=repr(expected
)):
821 self
.assertEqual(repr(fixed_rsqrt(radicand
)),
824 def test_misc_cases(self
):
827 (Fixed(0.5, 30, 32, False),
828 "RootRemainder(fixed:0x1.6a09e664, "
829 "fixed:0x0.0000000596d014780000000)")
831 for radicand
, expected
in test_cases
:
832 with self
.subTest(radicand
=str(radicand
), expected
=expected
):
833 self
.assertEqual(str(fixed_rsqrt(radicand
)), expected
)
836 class TestFixedRSqrt(unittest
.TestCase
):
837 def helper(self
, log2_radix
):
838 for bit_width
in range(1, 8):
839 for fract_width
in range(bit_width
):
840 for radicand_bits
in range(1, 1 << bit_width
):
841 radicand
= Fixed
.from_bits(radicand_bits
,
845 root_remainder
= fixed_rsqrt(radicand
)
846 with self
.subTest(radicand
=repr(radicand
),
847 root_remainder
=repr(root_remainder
),
848 log2_radix
=log2_radix
):
849 obj
= FixedRSqrt(radicand
, log2_radix
)
850 for _
in range(250 * bit_width
):
851 self
.assertEqual(obj
.radicand
* obj
.root
,
853 self
.assertEqual(obj
.radicand_root
* obj
.root
,
854 obj
.radicand_root_squared
)
855 self
.assertGreaterEqual(1,
856 obj
.radicand_root_squared
)
857 if obj
.calculate_stage():
860 self
.fail("infinite loop")
861 self
.assertEqual(obj
.radicand
* obj
.root
,
863 self
.assertEqual(obj
.radicand_root
* obj
.root
,
864 obj
.radicand_root_squared
)
865 self
.assertGreaterEqual(1,
866 obj
.radicand_root_squared
)
867 self
.assertEqual(obj
.remainder
,
868 1 - obj
.radicand_root_squared
)
869 self
.assertEqual(obj
.root
, root_remainder
.root
)
870 self
.assertEqual(obj
.remainder
,
871 root_remainder
.remainder
)
873 def test_radix_2(self
):
876 def test_radix_4(self
):
879 def test_radix_8(self
):
882 def test_radix_16(self
):
886 class TestFixedUDivRemSqrtRSqrt(unittest
.TestCase
):
888 def show_fixed(bits
, fract_width
, bit_width
):
889 fixed
= Fixed
.from_bits(bits
, fract_width
, bit_width
, False)
890 return f
"{str(fixed)}:{repr(fixed)}"
892 def check_invariants(self
,
900 self
.assertEqual(obj
.dividend
, dividend
)
901 self
.assertEqual(obj
.divisor_radicand
, divisor_radicand
)
902 self
.assertEqual(obj
.operation
, operation
)
903 self
.assertEqual(obj
.bit_width
, bit_width
)
904 self
.assertEqual(obj
.fract_width
, fract_width
)
905 self
.assertEqual(obj
.log2_radix
, log2_radix
)
906 self
.assertEqual(obj
.root_times_radicand
,
907 obj
.quotient_root
* obj
.divisor_radicand
)
908 self
.assertGreaterEqual(obj
.compare_lhs
, obj
.compare_rhs
)
909 self
.assertEqual(obj
.remainder
, obj
.compare_lhs
- obj
.compare_rhs
)
910 if operation
is Operation
.UDivRem
:
911 self
.assertEqual(obj
.compare_lhs
, obj
.dividend
<< fract_width
)
912 self
.assertEqual(obj
.compare_rhs
,
913 (obj
.quotient_root
* obj
.divisor_radicand
)
915 elif operation
is Operation
.SqrtRem
:
916 self
.assertEqual(obj
.compare_lhs
,
917 obj
.divisor_radicand
<< (fract_width
* 2))
918 self
.assertEqual(obj
.compare_rhs
,
919 (obj
.quotient_root
* obj
.quotient_root
)
922 assert operation
is Operation
.RSqrtRem
923 self
.assertEqual(obj
.compare_lhs
,
924 1 << (fract_width
* 3))
925 self
.assertEqual(obj
.compare_rhs
,
926 obj
.quotient_root
* obj
.quotient_root
927 * obj
.divisor_radicand
)
929 def handle_case(self
,
936 dividend_str
= self
.show_fixed(dividend
,
938 bit_width
+ fract_width
)
939 divisor_radicand_str
= self
.show_fixed(divisor_radicand
,
942 with self
.subTest(dividend
=dividend_str
,
943 divisor_radicand
=divisor_radicand_str
,
944 operation
=operation
.name
,
946 fract_width
=fract_width
,
947 log2_radix
=log2_radix
):
948 if operation
is Operation
.UDivRem
:
949 if divisor_radicand
== 0:
951 quotient_root
, remainder
= div_rem(dividend
,
955 remainder
<<= fract_width
956 elif operation
is Operation
.SqrtRem
:
957 root_remainder
= fixed_sqrt(Fixed
.from_bits(divisor_radicand
,
961 self
.assertEqual(root_remainder
.root
.bit_width
,
963 self
.assertEqual(root_remainder
.root
.fract_width
,
965 self
.assertEqual(root_remainder
.remainder
.bit_width
,
967 self
.assertEqual(root_remainder
.remainder
.fract_width
,
969 quotient_root
= root_remainder
.root
.bits
970 remainder
= root_remainder
.remainder
.bits
<< fract_width
972 assert operation
is Operation
.RSqrtRem
973 if divisor_radicand
== 0:
975 root_remainder
= fixed_rsqrt(Fixed
.from_bits(divisor_radicand
,
979 self
.assertEqual(root_remainder
.root
.bit_width
,
981 self
.assertEqual(root_remainder
.root
.fract_width
,
983 self
.assertEqual(root_remainder
.remainder
.bit_width
,
985 self
.assertEqual(root_remainder
.remainder
.fract_width
,
987 quotient_root
= root_remainder
.root
.bits
988 remainder
= root_remainder
.remainder
.bits
989 if quotient_root
>= (1 << bit_width
):
991 quotient_root_str
= self
.show_fixed(quotient_root
,
994 remainder_str
= self
.show_fixed(remainder
,
997 with self
.subTest(quotient_root
=quotient_root_str
,
998 remainder
=remainder_str
):
999 obj
= FixedUDivRemSqrtRSqrt(dividend
,
1005 for _
in range(250 * bit_width
):
1006 self
.check_invariants(dividend
,
1013 if obj
.calculate_stage():
1016 self
.fail("infinite loop")
1017 self
.check_invariants(dividend
,
1024 self
.assertEqual(obj
.quotient_root
, quotient_root
)
1025 self
.assertEqual(obj
.remainder
, remainder
)
1027 def helper(self
, log2_radix
, operation
):
1028 bit_width_range
= range(1, 8)
1029 if operation
is Operation
.UDivRem
:
1030 bit_width_range
= range(1, 6)
1031 for bit_width
in bit_width_range
:
1032 for fract_width
in range(bit_width
):
1033 for divisor_radicand
in range(1 << bit_width
):
1034 dividend_range
= range(1)
1035 if operation
is Operation
.UDivRem
:
1036 dividend_range
= range(1 << (bit_width
+ fract_width
))
1037 for dividend
in dividend_range
:
1038 self
.handle_case(dividend
,
1045 def test_radix_2_UDiv(self
):
1046 self
.helper(1, Operation
.UDivRem
)
1048 def test_radix_4_UDiv(self
):
1049 self
.helper(2, Operation
.UDivRem
)
1051 def test_radix_8_UDiv(self
):
1052 self
.helper(3, Operation
.UDivRem
)
1054 def test_radix_16_UDiv(self
):
1055 self
.helper(4, Operation
.UDivRem
)
1057 def test_radix_2_Sqrt(self
):
1058 self
.helper(1, Operation
.SqrtRem
)
1060 def test_radix_4_Sqrt(self
):
1061 self
.helper(2, Operation
.SqrtRem
)
1063 def test_radix_8_Sqrt(self
):
1064 self
.helper(3, Operation
.SqrtRem
)
1066 def test_radix_16_Sqrt(self
):
1067 self
.helper(4, Operation
.SqrtRem
)
1069 def test_radix_2_RSqrt(self
):
1070 self
.helper(1, Operation
.RSqrtRem
)
1072 def test_radix_4_RSqrt(self
):
1073 self
.helper(2, Operation
.RSqrtRem
)
1075 def test_radix_8_RSqrt(self
):
1076 self
.helper(3, Operation
.RSqrtRem
)
1078 def test_radix_16_RSqrt(self
):
1079 self
.helper(4, Operation
.RSqrtRem
)
1081 def test_int_div(self
):
1085 for dividend
in range(1 << bit_width
):
1086 for divisor
in range(1, 1 << bit_width
):
1087 obj
= FixedUDivRemSqrtRSqrt(dividend
,
1094 quotient
, remainder
= div_rem(dividend
,
1098 shifted_remainder
= remainder
<< fract_width
1099 with self
.subTest(dividend
=dividend
,
1102 remainder
=remainder
,
1103 shifted_remainder
=shifted_remainder
):
1104 self
.assertEqual(obj
.quotient_root
, quotient
)
1105 self
.assertEqual(obj
.remainder
, shifted_remainder
)
1107 def test_fract_div(self
):
1111 for dividend
in range(1 << bit_width
):
1112 for divisor
in range(1, 1 << bit_width
):
1113 obj
= FixedUDivRemSqrtRSqrt(dividend
<< fract_width
,
1120 quotient
= (dividend
<< fract_width
) // divisor
1121 if quotient
>= (1 << bit_width
):
1123 remainder
= (dividend
<< fract_width
) % divisor
1124 shifted_remainder
= remainder
<< fract_width
1125 with self
.subTest(dividend
=dividend
,
1128 remainder
=remainder
,
1129 shifted_remainder
=shifted_remainder
):
1130 self
.assertEqual(obj
.quotient_root
, quotient
)
1131 self
.assertEqual(obj
.remainder
, shifted_remainder
)