2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
5 from nmigen
.hdl
.ast
import Const
6 from .algorithm
import (div_rem
, UnsignedDivRem
, DivRem
,
7 Fixed
, RootRemainder
, fixed_sqrt
, FixedSqrt
,
8 fixed_rsqrt
, FixedRSqrt
, Operation
,
14 class TestDivRemFn(unittest
.TestCase
):
15 def test_signed(self
):
17 # numerator, denominator, quotient, remainder
130 (-8, -1, -8, 0), # overflows and wraps around
275 for (n
, d
, q
, r
) in test_cases
:
276 self
.assertEqual(div_rem(n
, d
, 4, True), (q
, r
))
278 def test_unsigned(self
):
285 # div_rem matches // and % for unsigned integers
288 self
.assertEqual(div_rem(n
, d
, 4, False), (q
, r
))
291 class TestUnsignedDivRem(unittest
.TestCase
):
292 def helper(self
, log2_radix
):
294 for n
in range(1 << bit_width
):
295 for d
in range(1 << bit_width
):
296 q
, r
= div_rem(n
, d
, bit_width
, False)
297 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
):
298 udr
= UnsignedDivRem(n
, d
, bit_width
, log2_radix
)
299 for _
in range(250 * bit_width
):
300 self
.assertEqual(udr
.dividend
, n
)
301 self
.assertEqual(udr
.divisor
, d
)
302 self
.assertEqual(udr
.quotient_times_divisor
,
303 udr
.quotient
* udr
.divisor
)
304 self
.assertGreaterEqual(udr
.dividend
,
305 udr
.quotient_times_divisor
)
306 if udr
.calculate_stage():
309 self
.fail("infinite loop")
310 self
.assertEqual(udr
.dividend
, n
)
311 self
.assertEqual(udr
.divisor
, d
)
312 self
.assertEqual(udr
.quotient_times_divisor
,
313 udr
.quotient
* udr
.divisor
)
314 self
.assertGreaterEqual(udr
.dividend
,
315 udr
.quotient_times_divisor
)
316 self
.assertEqual(udr
.quotient
, q
)
317 self
.assertEqual(udr
.remainder
, r
)
319 def test_radix_2(self
):
322 def test_radix_4(self
):
325 def test_radix_8(self
):
328 def test_radix_16(self
):
332 class TestDivRem(unittest
.TestCase
):
333 def helper(self
, log2_radix
):
335 for n
in range(1 << bit_width
):
336 for d
in range(1 << bit_width
):
337 for signed
in False, True:
338 n
= Const
.normalize(n
, (bit_width
, signed
))
339 d
= Const
.normalize(d
, (bit_width
, signed
))
340 q
, r
= div_rem(n
, d
, bit_width
, signed
)
341 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
, signed
=signed
):
342 dr
= DivRem(n
, d
, bit_width
, signed
, log2_radix
)
343 for _
in range(250 * bit_width
):
344 if dr
.calculate_stage():
347 self
.fail("infinite loop")
348 self
.assertEqual(dr
.quotient
, q
)
349 self
.assertEqual(dr
.remainder
, r
)
351 def test_radix_2(self
):
354 def test_radix_4(self
):
357 def test_radix_8(self
):
360 def test_radix_16(self
):
364 class TestFixed(unittest
.TestCase
):
def test_constructor(self):
    """Check the fields stored by the Fixed constructor.

    Each case builds a Fixed from an int, a float or another Fixed. It then
    checks ``bits``, ``fract_width``, ``bit_width`` and ``signed``, in that
    order.
    """
    def check(fixed, bits, fract_width, bit_width, signed):
        # Assert the four attributes in the same order as the original
        # inline assertions, and hand the value back for reuse.
        self.assertEqual(fixed.bits, bits)
        self.assertEqual(fixed.fract_width, fract_width)
        self.assertEqual(fixed.bit_width, bit_width)
        self.assertEqual(fixed.signed, signed)
        return fixed

    check(Fixed(0, 0, 1, False), 0, 0, 1, False)
    # 1 with 2 fraction bits is the bit pattern 0b100. In a signed 3-bit
    # value that pattern reads as -4.
    check(Fixed(1, 2, 3, True), -4, 2, 3, True)
    check(Fixed(1, 2, 4, True), 4, 2, 4, True)
    check(Fixed(1.25, 4, 8, True), 0x14, 4, 8, True)
    check(Fixed(Fixed(2, 0, 12, False), 4, 8, True), 0x20, 4, 8, True)
    wide = check(Fixed(0x2FF / 2 ** 8, 8, 12, False), 0x2FF, 8, 12, False)
    # Re-encoding with 4 instead of 8 fraction bits: the expected bits are
    # 0x2FF without its low 4 bits.
    check(Fixed(wide, 4, 8, True), 0x2F, 4, 8, True)
402 def helper_tst_from_bits(self
, bit_width
, fract_width
):
404 for bits
in range(1 << bit_width
):
405 with self
.subTest(bit_width
=bit_width
,
406 fract_width
=fract_width
,
409 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
410 self
.assertEqual(value
.bit_width
, bit_width
)
411 self
.assertEqual(value
.fract_width
, fract_width
)
412 self
.assertEqual(value
.signed
, signed
)
413 self
.assertEqual(value
.bits
, bits
)
415 for bits
in range(-1 << (bit_width
- 1), 1 << (bit_width
- 1)):
416 with self
.subTest(bit_width
=bit_width
,
417 fract_width
=fract_width
,
420 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
421 self
.assertEqual(value
.bit_width
, bit_width
)
422 self
.assertEqual(value
.fract_width
, fract_width
)
423 self
.assertEqual(value
.signed
, signed
)
424 self
.assertEqual(value
.bits
, bits
)
def test_from_bits(self):
    """Run helper_tst_from_bits for each bit_width 1..4 and fract_width < bit_width."""
    # Same iteration order as two nested loops: outer bit_width, inner
    # fract_width.
    width_pairs = ((bit_width, fract_width)
                   for bit_width in range(1, 5)
                   for fract_width in range(bit_width))
    for bit_width, fract_width in width_pairs:
        self.helper_tst_from_bits(bit_width, fract_width)
432 self
.assertEqual(repr(Fixed
.from_bits(1, 2, 3, False)),
433 "Fixed.from_bits(1, 2, 3, False)")
434 self
.assertEqual(repr(Fixed
.from_bits(-4, 2, 3, True)),
435 "Fixed.from_bits(-4, 2, 3, True)")
436 self
.assertEqual(repr(Fixed
.from_bits(-4, 7, 10, True)),
437 "Fixed.from_bits(-4, 7, 10, True)")
def test_trunc(self):
    """Check math.trunc() on every signed 4-bit Fixed with 2 fraction bits.

    The reference value is math.trunc of the raw bits divided by
    2 ** fract_width.
    """
    bit_width, fract_width = 4, 2
    scale = 1 << fract_width
    half_range = 1 << (bit_width - 1)
    for raw_bits in range(-half_range, half_range):
        fixed_value = Fixed.from_bits(raw_bits, fract_width, bit_width, True)
        with self.subTest(value=repr(fixed_value)):
            actual = math.trunc(fixed_value)
            self.assertEqual(actual, math.trunc(raw_bits / scale))
446 for i
in range(-8, 8):
447 value
= Fixed
.from_bits(i
, 2, 4, True)
448 with self
.subTest(value
=repr(value
)):
449 self
.assertEqual(int(value
), math
.trunc(value
))
def test_float(self):
    """Check float() on every signed 4-bit Fixed with 2 fraction bits.

    The expected float is the raw bits divided by 2 ** fract_width.
    """
    bit_width, fract_width = 4, 2
    scale = 1 << fract_width
    half_range = 1 << (bit_width - 1)
    for raw_bits in range(-half_range, half_range):
        fixed_value = Fixed.from_bits(raw_bits, fract_width, bit_width, True)
        with self.subTest(value=repr(fixed_value)):
            actual = float(fixed_value)
            self.assertEqual(actual, raw_bits / scale)
def test_floor(self):
    """Check math.floor() on every signed 4-bit Fixed with 2 fraction bits.

    The reference value is math.floor of the raw bits divided by
    2 ** fract_width.
    """
    bit_width, fract_width = 4, 2
    scale = 1 << fract_width
    half_range = 1 << (bit_width - 1)
    for raw_bits in range(-half_range, half_range):
        fixed_value = Fixed.from_bits(raw_bits, fract_width, bit_width, True)
        with self.subTest(value=repr(fixed_value)):
            actual = math.floor(fixed_value)
            self.assertEqual(actual, math.floor(raw_bits / scale))
464 for i
in range(-8, 8):
465 value
= Fixed
.from_bits(i
, 2, 4, True)
466 with self
.subTest(value
=repr(value
)):
467 self
.assertEqual(math
.ceil(value
), math
.ceil(i
/ 4))
470 for i
in range(-8, 8):
471 value
= Fixed
.from_bits(i
, 2, 4, True)
472 expected
= -i
/ 4 if i
!= -8 else -2.0 # handle wrap-around
473 with self
.subTest(value
=repr(value
)):
474 self
.assertEqual(float(-value
), expected
)
477 for i
in range(-8, 8):
478 value
= Fixed
.from_bits(i
, 2, 4, True)
479 with self
.subTest(value
=repr(value
)):
481 self
.assertEqual(value
.bits
, i
)
484 for i
in range(-8, 8):
485 value
= Fixed
.from_bits(i
, 2, 4, True)
486 expected
= abs(i
) / 4 if i
!= -8 else -2.0 # handle wrap-around
487 with self
.subTest(value
=repr(value
)):
488 self
.assertEqual(float(abs(value
)), expected
)
491 for i
in range(-8, 8):
492 value
= Fixed
.from_bits(i
, 2, 4, True)
493 with self
.subTest(value
=repr(value
)):
494 self
.assertEqual(float(~value
), (~i
) / 4)
497 def get_test_values(max_bit_width
, include_int
):
498 for signed
in False, True:
500 for bits
in range(1 << max_bit_width
):
501 int_value
= Const
.normalize(bits
, (max_bit_width
, signed
))
503 for bit_width
in range(1, max_bit_width
):
504 for fract_width
in range(bit_width
+ 1):
505 for bits
in range(1 << bit_width
):
506 yield Fixed
.from_bits(bits
,
511 def binary_op_test_helper(self
,
514 width_combine_op
=max,
515 adjust_bits_op
=None):
516 def default_adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
517 return bits
<< (out_fract_width
- in_fract_width
)
518 if adjust_bits_op
is None:
519 adjust_bits_op
= default_adjust_bits_op
521 for lhs
in self
.get_test_values(max_bit_width
, True):
522 lhs_is_int
= isinstance(lhs
, int)
523 for rhs
in self
.get_test_values(max_bit_width
, not lhs_is_int
):
524 rhs_is_int
= isinstance(rhs
, int)
526 assert not rhs_is_int
527 lhs_int
= adjust_bits_op(lhs
, rhs
.fract_width
, 0)
528 int_result
= operation(lhs_int
, rhs
.bits
)
530 expected
= Fixed
.from_bits(int_result
,
535 expected
= int_result
537 rhs_int
= adjust_bits_op(rhs
, lhs
.fract_width
, 0)
538 int_result
= operation(lhs
.bits
, rhs_int
)
540 expected
= Fixed
.from_bits(int_result
,
545 expected
= int_result
546 elif lhs
.signed
!= rhs
.signed
:
549 fract_width
= width_combine_op(lhs
.fract_width
,
551 int_width
= width_combine_op(lhs
.bit_width
555 bit_width
= fract_width
+ int_width
556 lhs_int
= adjust_bits_op(lhs
.bits
,
559 rhs_int
= adjust_bits_op(rhs
.bits
,
562 int_result
= operation(lhs_int
, rhs_int
)
564 expected
= Fixed
.from_bits(int_result
,
569 expected
= int_result
570 with self
.subTest(lhs
=repr(lhs
),
572 expected
=repr(expected
)):
573 result
= operation(lhs
, rhs
)
575 self
.assertEqual(result
.bit_width
, expected
.bit_width
)
576 self
.assertEqual(result
.signed
, expected
.signed
)
577 self
.assertEqual(result
.fract_width
,
578 expected
.fract_width
)
579 self
.assertEqual(result
.bits
, expected
.bits
)
581 self
.assertEqual(result
, expected
)
584 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
+ rhs
)
587 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
- rhs
)
590 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
& rhs
)
593 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs | rhs
)
596 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs ^ rhs
)
599 def adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
601 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
* rhs
,
603 lambda l_width
, r_width
: l_width
+ r_width
,
613 self
.binary_op_test_helper(cmp, False)
616 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
< rhs
, False)
619 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
<= rhs
, False)
622 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
== rhs
, False)
625 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
!= rhs
, False)
628 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
> rhs
, False)
631 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
>= rhs
, False)
634 for v
in self
.get_test_values(6, False):
635 with self
.subTest(v
=repr(v
)):
636 self
.assertEqual(bool(v
), bool(v
.bits
))
639 self
.assertEqual(str(Fixed
.from_bits(0x1234, 0, 16, False)),
641 self
.assertEqual(str(Fixed
.from_bits(-0x1234, 0, 16, True)),
643 self
.assertEqual(str(Fixed
.from_bits(0x12345, 3, 20, True)),
645 self
.assertEqual(str(Fixed(123.625, 3, 12, True)),
648 self
.assertEqual(str(Fixed
.from_bits(0x1, 0, 20, True)),
650 self
.assertEqual(str(Fixed
.from_bits(0x2, 1, 20, True)),
652 self
.assertEqual(str(Fixed
.from_bits(0x4, 2, 20, True)),
654 self
.assertEqual(str(Fixed
.from_bits(0x9, 3, 20, True)),
656 self
.assertEqual(str(Fixed
.from_bits(0x12, 4, 20, True)),
658 self
.assertEqual(str(Fixed
.from_bits(0x24, 5, 20, True)),
660 self
.assertEqual(str(Fixed
.from_bits(0x48, 6, 20, True)),
662 self
.assertEqual(str(Fixed
.from_bits(0x91, 7, 20, True)),
664 self
.assertEqual(str(Fixed
.from_bits(0x123, 8, 20, True)),
666 self
.assertEqual(str(Fixed
.from_bits(0x246, 9, 20, True)),
668 self
.assertEqual(str(Fixed
.from_bits(0x48d, 10, 20, True)),
670 self
.assertEqual(str(Fixed
.from_bits(0x91a, 11, 20, True)),
672 self
.assertEqual(str(Fixed
.from_bits(0x1234, 12, 20, True)),
674 self
.assertEqual(str(Fixed
.from_bits(0x2468, 13, 20, True)),
676 self
.assertEqual(str(Fixed
.from_bits(0x48d1, 14, 20, True)),
678 self
.assertEqual(str(Fixed
.from_bits(0x91a2, 15, 20, True)),
680 self
.assertEqual(str(Fixed
.from_bits(0x12345, 16, 20, True)),
682 self
.assertEqual(str(Fixed
.from_bits(0x2468a, 17, 20, True)),
684 self
.assertEqual(str(Fixed
.from_bits(0x48d14, 18, 20, True)),
686 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, True)),
688 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, False)),
692 class TestFixedSqrtFn(unittest
.TestCase
):
693 def test_on_ints(self
):
694 for radicand
in range(-1, 32):
698 root
= math
.floor(math
.sqrt(radicand
))
699 remainder
= radicand
- root
* root
700 expected
= RootRemainder(root
, remainder
)
701 with self
.subTest(radicand
=radicand
, expected
=expected
):
702 self
.assertEqual(repr(fixed_sqrt(radicand
)), repr(expected
))
705 remainder
= radicand
- root
* root
706 expected
= RootRemainder(root
, remainder
)
707 with self
.subTest(radicand
=radicand
, expected
=expected
):
708 self
.assertEqual(repr(fixed_sqrt(radicand
)), repr(expected
))
710 def test_on_fixed(self
):
711 for signed
in False, True:
712 for bit_width
in range(1, 10):
713 for fract_width
in range(bit_width
):
714 for bits
in range(1 << bit_width
):
715 radicand
= Fixed
.from_bits(bits
,
721 root
= radicand
.with_value(math
.sqrt(float(radicand
)))
722 remainder
= radicand
- root
* root
723 expected
= RootRemainder(root
, remainder
)
724 with self
.subTest(radicand
=repr(radicand
),
725 expected
=repr(expected
)):
726 self
.assertEqual(repr(fixed_sqrt(radicand
)),
729 def test_misc_cases(self
):
732 (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
733 (Fixed(2, 30, 32, False),
734 "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)")
736 for radicand
, expected
in test_cases
:
737 with self
.subTest(radicand
=str(radicand
), expected
=expected
):
738 self
.assertEqual(str(fixed_sqrt(radicand
)), expected
)
741 class TestFixedSqrt(unittest
.TestCase
):
742 def helper(self
, log2_radix
):
743 for bit_width
in range(1, 8):
744 for fract_width
in range(bit_width
):
745 for radicand_bits
in range(1 << bit_width
):
746 radicand
= Fixed
.from_bits(radicand_bits
,
750 root_remainder
= fixed_sqrt(radicand
)
751 with self
.subTest(radicand
=repr(radicand
),
752 root_remainder
=repr(root_remainder
),
753 log2_radix
=log2_radix
):
754 obj
= FixedSqrt(radicand
, log2_radix
)
755 for _
in range(250 * bit_width
):
756 self
.assertEqual(obj
.root
* obj
.root
,
758 self
.assertGreaterEqual(obj
.radicand
,
760 if obj
.calculate_stage():
763 self
.fail("infinite loop")
764 self
.assertEqual(obj
.root
* obj
.root
,
766 self
.assertGreaterEqual(obj
.radicand
,
768 self
.assertEqual(obj
.remainder
,
769 obj
.radicand
- obj
.root_squared
)
770 self
.assertEqual(obj
.root
, root_remainder
.root
)
771 self
.assertEqual(obj
.remainder
,
772 root_remainder
.remainder
)
774 def test_radix_2(self
):
777 def test_radix_4(self
):
780 def test_radix_8(self
):
783 def test_radix_16(self
):
787 class TestFixedRSqrtFn(unittest
.TestCase
):
789 for bits
in range(1, 1 << 5):
790 radicand
= Fixed
.from_bits(bits
, 5, 12, False)
791 float_root
= 1 / math
.sqrt(float(radicand
))
792 root
= radicand
.with_value(float_root
)
793 remainder
= 1 - root
* root
* radicand
794 expected
= RootRemainder(root
, remainder
)
795 with self
.subTest(radicand
=repr(radicand
),
796 expected
=repr(expected
)):
797 self
.assertEqual(repr(fixed_rsqrt(radicand
)),
801 for signed
in False, True:
802 for bit_width
in range(1, 10):
803 for fract_width
in range(bit_width
):
804 for bits
in range(1 << bit_width
):
805 radicand
= Fixed
.from_bits(bits
,
811 float_root
= 1 / math
.sqrt(float(radicand
))
812 max_value
= radicand
.with_bits(
813 (1 << (bit_width
- signed
)) - 1)
814 if float_root
> float(max_value
):
817 root
= radicand
.with_value(float_root
)
818 remainder
= 1 - root
* root
* radicand
819 expected
= RootRemainder(root
, remainder
)
820 with self
.subTest(radicand
=repr(radicand
),
821 expected
=repr(expected
)):
822 self
.assertEqual(repr(fixed_rsqrt(radicand
)),
825 def test_misc_cases(self
):
828 (Fixed(0.5, 30, 32, False),
829 "RootRemainder(fixed:0x1.6a09e664, "
830 "fixed:0x0.0000000596d014780000000)")
832 for radicand
, expected
in test_cases
:
833 with self
.subTest(radicand
=str(radicand
), expected
=expected
):
834 self
.assertEqual(str(fixed_rsqrt(radicand
)), expected
)
837 class TestFixedRSqrt(unittest
.TestCase
):
838 def helper(self
, log2_radix
):
839 for bit_width
in range(1, 8):
840 for fract_width
in range(bit_width
):
841 for radicand_bits
in range(1, 1 << bit_width
):
842 radicand
= Fixed
.from_bits(radicand_bits
,
846 root_remainder
= fixed_rsqrt(radicand
)
847 with self
.subTest(radicand
=repr(radicand
),
848 root_remainder
=repr(root_remainder
),
849 log2_radix
=log2_radix
):
850 obj
= FixedRSqrt(radicand
, log2_radix
)
851 for _
in range(250 * bit_width
):
852 self
.assertEqual(obj
.radicand
* obj
.root
,
854 self
.assertEqual(obj
.radicand_root
* obj
.root
,
855 obj
.radicand_root_squared
)
856 self
.assertGreaterEqual(1,
857 obj
.radicand_root_squared
)
858 if obj
.calculate_stage():
861 self
.fail("infinite loop")
862 self
.assertEqual(obj
.radicand
* obj
.root
,
864 self
.assertEqual(obj
.radicand_root
* obj
.root
,
865 obj
.radicand_root_squared
)
866 self
.assertGreaterEqual(1,
867 obj
.radicand_root_squared
)
868 self
.assertEqual(obj
.remainder
,
869 1 - obj
.radicand_root_squared
)
870 self
.assertEqual(obj
.root
, root_remainder
.root
)
871 self
.assertEqual(obj
.remainder
,
872 root_remainder
.remainder
)
874 def test_radix_2(self
):
877 def test_radix_4(self
):
880 def test_radix_8(self
):
883 def test_radix_16(self
):
887 class TestFixedUDivRemSqrtRSqrt(unittest
.TestCase
):
def show_fixed(bits, fract_width, bit_width):
    """Render an unsigned fixed-point bit pattern as ``"<str>:<repr>"``.

    Used to build readable subTest labels. Takes no ``self`` but is called
    as ``self.show_fixed(...)``, so it is presumably decorated with
    ``@staticmethod`` on the preceding (not visible) line — confirm.
    """
    value = Fixed.from_bits(bits, fract_width, bit_width, False)
    return f"{value!s}:{value!r}"
893 def check_invariants(self
,
901 self
.assertEqual(obj
.dividend
, dividend
)
902 self
.assertEqual(obj
.divisor_radicand
, divisor_radicand
)
903 self
.assertEqual(obj
.operation
, operation
)
904 self
.assertEqual(obj
.bit_width
, bit_width
)
905 self
.assertEqual(obj
.fract_width
, fract_width
)
906 self
.assertEqual(obj
.log2_radix
, log2_radix
)
907 self
.assertEqual(obj
.root_times_radicand
,
908 obj
.quotient_root
* obj
.divisor_radicand
)
909 self
.assertGreaterEqual(obj
.compare_lhs
, obj
.compare_rhs
)
910 self
.assertEqual(obj
.remainder
, obj
.compare_lhs
- obj
.compare_rhs
)
911 if operation
is Operation
.UDivRem
:
912 self
.assertEqual(obj
.compare_lhs
, obj
.dividend
<< fract_width
)
913 self
.assertEqual(obj
.compare_rhs
,
914 (obj
.quotient_root
* obj
.divisor_radicand
)
916 elif operation
is Operation
.SqrtRem
:
917 self
.assertEqual(obj
.compare_lhs
,
918 obj
.divisor_radicand
<< (fract_width
* 2))
919 self
.assertEqual(obj
.compare_rhs
,
920 (obj
.quotient_root
* obj
.quotient_root
)
923 assert operation
is Operation
.RSqrtRem
924 self
.assertEqual(obj
.compare_lhs
,
925 1 << (fract_width
* 3))
926 self
.assertEqual(obj
.compare_rhs
,
927 obj
.quotient_root
* obj
.quotient_root
928 * obj
.divisor_radicand
)
930 def handle_case(self
,
937 dividend_str
= self
.show_fixed(dividend
,
939 bit_width
+ fract_width
)
940 divisor_radicand_str
= self
.show_fixed(divisor_radicand
,
943 with self
.subTest(dividend
=dividend_str
,
944 divisor_radicand
=divisor_radicand_str
,
945 operation
=operation
.name
,
947 fract_width
=fract_width
,
948 log2_radix
=log2_radix
):
949 if operation
is Operation
.UDivRem
:
950 if divisor_radicand
== 0:
952 quotient_root
, remainder
= div_rem(dividend
,
956 remainder
<<= fract_width
957 elif operation
is Operation
.SqrtRem
:
958 root_remainder
= fixed_sqrt(Fixed
.from_bits(divisor_radicand
,
962 self
.assertEqual(root_remainder
.root
.bit_width
,
964 self
.assertEqual(root_remainder
.root
.fract_width
,
966 self
.assertEqual(root_remainder
.remainder
.bit_width
,
968 self
.assertEqual(root_remainder
.remainder
.fract_width
,
970 quotient_root
= root_remainder
.root
.bits
971 remainder
= root_remainder
.remainder
.bits
<< fract_width
973 assert operation
is Operation
.RSqrtRem
974 if divisor_radicand
== 0:
976 root_remainder
= fixed_rsqrt(Fixed
.from_bits(divisor_radicand
,
980 self
.assertEqual(root_remainder
.root
.bit_width
,
982 self
.assertEqual(root_remainder
.root
.fract_width
,
984 self
.assertEqual(root_remainder
.remainder
.bit_width
,
986 self
.assertEqual(root_remainder
.remainder
.fract_width
,
988 quotient_root
= root_remainder
.root
.bits
989 remainder
= root_remainder
.remainder
.bits
990 if quotient_root
>= (1 << bit_width
):
992 quotient_root_str
= self
.show_fixed(quotient_root
,
995 remainder_str
= self
.show_fixed(remainder
,
998 with self
.subTest(quotient_root
=quotient_root_str
,
999 remainder
=remainder_str
):
1000 obj
= FixedUDivRemSqrtRSqrt(dividend
,
1006 for _
in range(250 * bit_width
):
1007 self
.check_invariants(dividend
,
1014 if obj
.calculate_stage():
1017 self
.fail("infinite loop")
1018 self
.check_invariants(dividend
,
1025 self
.assertEqual(obj
.quotient_root
, quotient_root
)
1026 self
.assertEqual(obj
.remainder
, remainder
)
1028 def helper(self
, log2_radix
, operation
):
1029 bit_width_range
= range(1, 8)
1030 if operation
is Operation
.UDivRem
:
1031 bit_width_range
= range(1, 6)
1032 for bit_width
in bit_width_range
:
1033 for fract_width
in range(bit_width
):
1034 for divisor_radicand
in range(1 << bit_width
):
1035 dividend_range
= range(1)
1036 if operation
is Operation
.UDivRem
:
1037 dividend_range
= range(1 << (bit_width
+ fract_width
))
1038 for dividend
in dividend_range
:
1039 self
.handle_case(dividend
,
def test_radix_2_UDiv(self):
    """Exercise Operation.UDivRem with log2_radix=1 (radix 2)."""
    self.helper(log2_radix=1, operation=Operation.UDivRem)
def test_radix_4_UDiv(self):
    """Exercise Operation.UDivRem with log2_radix=2 (radix 4)."""
    self.helper(log2_radix=2, operation=Operation.UDivRem)
def test_radix_8_UDiv(self):
    """Exercise Operation.UDivRem with log2_radix=3 (radix 8)."""
    self.helper(log2_radix=3, operation=Operation.UDivRem)
def test_radix_16_UDiv(self):
    """Exercise Operation.UDivRem with log2_radix=4 (radix 16)."""
    self.helper(log2_radix=4, operation=Operation.UDivRem)
def test_radix_2_Sqrt(self):
    """Exercise Operation.SqrtRem with log2_radix=1 (radix 2)."""
    self.helper(log2_radix=1, operation=Operation.SqrtRem)
def test_radix_4_Sqrt(self):
    """Exercise Operation.SqrtRem with log2_radix=2 (radix 4)."""
    self.helper(log2_radix=2, operation=Operation.SqrtRem)
def test_radix_8_Sqrt(self):
    """Exercise Operation.SqrtRem with log2_radix=3 (radix 8)."""
    self.helper(log2_radix=3, operation=Operation.SqrtRem)
def test_radix_16_Sqrt(self):
    """Exercise Operation.SqrtRem with log2_radix=4 (radix 16)."""
    self.helper(log2_radix=4, operation=Operation.SqrtRem)
def test_radix_2_RSqrt(self):
    """Exercise Operation.RSqrtRem with log2_radix=1 (radix 2)."""
    self.helper(log2_radix=1, operation=Operation.RSqrtRem)
def test_radix_4_RSqrt(self):
    """Exercise Operation.RSqrtRem with log2_radix=2 (radix 4)."""
    self.helper(log2_radix=2, operation=Operation.RSqrtRem)
def test_radix_8_RSqrt(self):
    """Exercise Operation.RSqrtRem with log2_radix=3 (radix 8)."""
    self.helper(log2_radix=3, operation=Operation.RSqrtRem)
def test_radix_16_RSqrt(self):
    """Exercise Operation.RSqrtRem with log2_radix=4 (radix 16)."""
    self.helper(log2_radix=4, operation=Operation.RSqrtRem)
1082 def test_int_div(self
):
1086 for dividend
in range(1 << bit_width
):
1087 for divisor
in range(1, 1 << bit_width
):
1088 obj
= FixedUDivRemSqrtRSqrt(dividend
,
1095 quotient
, remainder
= div_rem(dividend
,
1099 shifted_remainder
= remainder
<< fract_width
1100 with self
.subTest(dividend
=dividend
,
1103 remainder
=remainder
,
1104 shifted_remainder
=shifted_remainder
):
1105 self
.assertEqual(obj
.quotient_root
, quotient
)
1106 self
.assertEqual(obj
.remainder
, shifted_remainder
)
1108 def test_fract_div(self
):
1112 for dividend
in range(1 << bit_width
):
1113 for divisor
in range(1, 1 << bit_width
):
1114 obj
= FixedUDivRemSqrtRSqrt(dividend
<< fract_width
,
1121 quotient
= (dividend
<< fract_width
) // divisor
1122 if quotient
>= (1 << bit_width
):
1124 remainder
= (dividend
<< fract_width
) % divisor
1125 shifted_remainder
= remainder
<< fract_width
1126 with self
.subTest(dividend
=dividend
,
1129 remainder
=remainder
,
1130 shifted_remainder
=shifted_remainder
):
1131 self
.assertEqual(obj
.quotient_root
, quotient
)
1132 self
.assertEqual(obj
.remainder
, shifted_remainder
)
1135 if __name__
== '__main__':