1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
4 from nmigen
.hdl
.ast
import Const
5 from .algorithm
import (div_rem
, UnsignedDivRem
, DivRem
,
6 Fixed
, RootRemainder
, fixed_sqrt
, FixedSqrt
,
7 fixed_rsqrt
, FixedRSqrt
, Operation
,
13 class TestDivRemFn(unittest
.TestCase
):
14 def test_signed(self
):
16 # numerator, denominator, quotient, remainder
129 (-8, -1, -8, 0), # overflows and wraps around
274 for (n
, d
, q
, r
) in test_cases
:
275 self
.assertEqual(div_rem(n
, d
, 4, True), (q
, r
))
277 def test_unsigned(self
):
284 # div_rem matches // and % for unsigned integers
287 self
.assertEqual(div_rem(n
, d
, 4, False), (q
, r
))
290 class TestUnsignedDivRem(unittest
.TestCase
):
291 def helper(self
, log2_radix
):
293 for n
in range(1 << bit_width
):
294 for d
in range(1 << bit_width
):
295 q
, r
= div_rem(n
, d
, bit_width
, False)
296 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
):
297 udr
= UnsignedDivRem(n
, d
, bit_width
, log2_radix
)
298 for _
in range(250 * bit_width
):
299 self
.assertEqual(n
, udr
.quotient
* udr
.divisor
301 if udr
.calculate_stage():
304 self
.fail("infinite loop")
305 self
.assertEqual(n
, udr
.quotient
* udr
.divisor
307 self
.assertEqual(udr
.quotient
, q
)
308 self
.assertEqual(udr
.remainder
, r
)
310 def test_radix_2(self
):
313 def test_radix_4(self
):
316 def test_radix_8(self
):
319 def test_radix_16(self
):
323 class TestDivRem(unittest
.TestCase
):
324 def helper(self
, log2_radix
):
326 for n
in range(1 << bit_width
):
327 for d
in range(1 << bit_width
):
328 for signed
in False, True:
329 n
= Const
.normalize(n
, (bit_width
, signed
))
330 d
= Const
.normalize(d
, (bit_width
, signed
))
331 q
, r
= div_rem(n
, d
, bit_width
, signed
)
332 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
, signed
=signed
):
333 dr
= DivRem(n
, d
, bit_width
, signed
, log2_radix
)
334 for _
in range(250 * bit_width
):
335 if dr
.calculate_stage():
338 self
.fail("infinite loop")
339 self
.assertEqual(dr
.quotient
, q
)
340 self
.assertEqual(dr
.remainder
, r
)
342 def test_radix_2(self
):
345 def test_radix_4(self
):
348 def test_radix_8(self
):
351 def test_radix_16(self
):
355 class TestFixed(unittest
.TestCase
):
356 def test_constructor(self
):
357 value
= Fixed(0, 0, 1, False)
358 self
.assertEqual(value
.bits
, 0)
359 self
.assertEqual(value
.fract_width
, 0)
360 self
.assertEqual(value
.bit_width
, 1)
361 self
.assertEqual(value
.signed
, False)
362 value
= Fixed(1, 2, 3, True)
363 self
.assertEqual(value
.bits
, -4)
364 self
.assertEqual(value
.fract_width
, 2)
365 self
.assertEqual(value
.bit_width
, 3)
366 self
.assertEqual(value
.signed
, True)
367 value
= Fixed(1, 2, 4, True)
368 self
.assertEqual(value
.bits
, 4)
369 self
.assertEqual(value
.fract_width
, 2)
370 self
.assertEqual(value
.bit_width
, 4)
371 self
.assertEqual(value
.signed
, True)
372 value
= Fixed(1.25, 4, 8, True)
373 self
.assertEqual(value
.bits
, 0x14)
374 self
.assertEqual(value
.fract_width
, 4)
375 self
.assertEqual(value
.bit_width
, 8)
376 self
.assertEqual(value
.signed
, True)
377 value
= Fixed(Fixed(2, 0, 12, False), 4, 8, True)
378 self
.assertEqual(value
.bits
, 0x20)
379 self
.assertEqual(value
.fract_width
, 4)
380 self
.assertEqual(value
.bit_width
, 8)
381 self
.assertEqual(value
.signed
, True)
382 value
= Fixed(0x2FF / 2 ** 8, 8, 12, False)
383 self
.assertEqual(value
.bits
, 0x2FF)
384 self
.assertEqual(value
.fract_width
, 8)
385 self
.assertEqual(value
.bit_width
, 12)
386 self
.assertEqual(value
.signed
, False)
387 value
= Fixed(value
, 4, 8, True)
388 self
.assertEqual(value
.bits
, 0x2F)
389 self
.assertEqual(value
.fract_width
, 4)
390 self
.assertEqual(value
.bit_width
, 8)
391 self
.assertEqual(value
.signed
, True)
393 def helper_tst_from_bits(self
, bit_width
, fract_width
):
395 for bits
in range(1 << bit_width
):
396 with self
.subTest(bit_width
=bit_width
,
397 fract_width
=fract_width
,
400 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
401 self
.assertEqual(value
.bit_width
, bit_width
)
402 self
.assertEqual(value
.fract_width
, fract_width
)
403 self
.assertEqual(value
.signed
, signed
)
404 self
.assertEqual(value
.bits
, bits
)
406 for bits
in range(-1 << (bit_width
- 1), 1 << (bit_width
- 1)):
407 with self
.subTest(bit_width
=bit_width
,
408 fract_width
=fract_width
,
411 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
412 self
.assertEqual(value
.bit_width
, bit_width
)
413 self
.assertEqual(value
.fract_width
, fract_width
)
414 self
.assertEqual(value
.signed
, signed
)
415 self
.assertEqual(value
.bits
, bits
)
417 def test_from_bits(self
):
418 for bit_width
in range(1, 5):
419 for fract_width
in range(bit_width
):
420 self
.helper_tst_from_bits(bit_width
, fract_width
)
423 self
.assertEqual(repr(Fixed
.from_bits(1, 2, 3, False)),
424 "Fixed.from_bits(1, 2, 3, False)")
425 self
.assertEqual(repr(Fixed
.from_bits(-4, 2, 3, True)),
426 "Fixed.from_bits(-4, 2, 3, True)")
427 self
.assertEqual(repr(Fixed
.from_bits(-4, 7, 10, True)),
428 "Fixed.from_bits(-4, 7, 10, True)")
430 def test_trunc(self
):
431 for i
in range(-8, 8):
432 value
= Fixed
.from_bits(i
, 2, 4, True)
433 with self
.subTest(value
=repr(value
)):
434 self
.assertEqual(math
.trunc(value
), math
.trunc(i
/ 4))
437 for i
in range(-8, 8):
438 value
= Fixed
.from_bits(i
, 2, 4, True)
439 with self
.subTest(value
=repr(value
)):
440 self
.assertEqual(int(value
), math
.trunc(value
))
442 def test_float(self
):
443 for i
in range(-8, 8):
444 value
= Fixed
.from_bits(i
, 2, 4, True)
445 with self
.subTest(value
=repr(value
)):
446 self
.assertEqual(float(value
), i
/ 4)
448 def test_floor(self
):
449 for i
in range(-8, 8):
450 value
= Fixed
.from_bits(i
, 2, 4, True)
451 with self
.subTest(value
=repr(value
)):
452 self
.assertEqual(math
.floor(value
), math
.floor(i
/ 4))
455 for i
in range(-8, 8):
456 value
= Fixed
.from_bits(i
, 2, 4, True)
457 with self
.subTest(value
=repr(value
)):
458 self
.assertEqual(math
.ceil(value
), math
.ceil(i
/ 4))
461 for i
in range(-8, 8):
462 value
= Fixed
.from_bits(i
, 2, 4, True)
463 expected
= -i
/ 4 if i
!= -8 else -2.0 # handle wrap-around
464 with self
.subTest(value
=repr(value
)):
465 self
.assertEqual(float(-value
), expected
)
468 for i
in range(-8, 8):
469 value
= Fixed
.from_bits(i
, 2, 4, True)
470 with self
.subTest(value
=repr(value
)):
472 self
.assertEqual(value
.bits
, i
)
475 for i
in range(-8, 8):
476 value
= Fixed
.from_bits(i
, 2, 4, True)
477 expected
= abs(i
) / 4 if i
!= -8 else -2.0 # handle wrap-around
478 with self
.subTest(value
=repr(value
)):
479 self
.assertEqual(float(abs(value
)), expected
)
482 for i
in range(-8, 8):
483 value
= Fixed
.from_bits(i
, 2, 4, True)
484 with self
.subTest(value
=repr(value
)):
485 self
.assertEqual(float(~value
), (~i
) / 4)
488 def get_test_values(max_bit_width
, include_int
):
489 for signed
in False, True:
491 for bits
in range(1 << max_bit_width
):
492 int_value
= Const
.normalize(bits
, (max_bit_width
, signed
))
494 for bit_width
in range(1, max_bit_width
):
495 for fract_width
in range(bit_width
+ 1):
496 for bits
in range(1 << bit_width
):
497 yield Fixed
.from_bits(bits
,
502 def binary_op_test_helper(self
,
505 width_combine_op
=max,
506 adjust_bits_op
=None):
507 def default_adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
508 return bits
<< (out_fract_width
- in_fract_width
)
509 if adjust_bits_op
is None:
510 adjust_bits_op
= default_adjust_bits_op
512 for lhs
in self
.get_test_values(max_bit_width
, True):
513 lhs_is_int
= isinstance(lhs
, int)
514 for rhs
in self
.get_test_values(max_bit_width
, not lhs_is_int
):
515 rhs_is_int
= isinstance(rhs
, int)
517 assert not rhs_is_int
518 lhs_int
= adjust_bits_op(lhs
, rhs
.fract_width
, 0)
519 int_result
= operation(lhs_int
, rhs
.bits
)
521 expected
= Fixed
.from_bits(int_result
,
526 expected
= int_result
528 rhs_int
= adjust_bits_op(rhs
, lhs
.fract_width
, 0)
529 int_result
= operation(lhs
.bits
, rhs_int
)
531 expected
= Fixed
.from_bits(int_result
,
536 expected
= int_result
537 elif lhs
.signed
!= rhs
.signed
:
540 fract_width
= width_combine_op(lhs
.fract_width
,
542 int_width
= width_combine_op(lhs
.bit_width
546 bit_width
= fract_width
+ int_width
547 lhs_int
= adjust_bits_op(lhs
.bits
,
550 rhs_int
= adjust_bits_op(rhs
.bits
,
553 int_result
= operation(lhs_int
, rhs_int
)
555 expected
= Fixed
.from_bits(int_result
,
560 expected
= int_result
561 with self
.subTest(lhs
=repr(lhs
),
563 expected
=repr(expected
)):
564 result
= operation(lhs
, rhs
)
566 self
.assertEqual(result
.bit_width
, expected
.bit_width
)
567 self
.assertEqual(result
.signed
, expected
.signed
)
568 self
.assertEqual(result
.fract_width
,
569 expected
.fract_width
)
570 self
.assertEqual(result
.bits
, expected
.bits
)
572 self
.assertEqual(result
, expected
)
575 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
+ rhs
)
578 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
- rhs
)
581 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
& rhs
)
584 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs | rhs
)
587 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs ^ rhs
)
590 def adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
592 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
* rhs
,
594 lambda l_width
, r_width
: l_width
+ r_width
,
604 self
.binary_op_test_helper(cmp, False)
607 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
< rhs
, False)
610 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
<= rhs
, False)
613 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
== rhs
, False)
616 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
!= rhs
, False)
619 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
> rhs
, False)
622 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
>= rhs
, False)
625 for v
in self
.get_test_values(6, False):
626 with self
.subTest(v
=repr(v
)):
627 self
.assertEqual(bool(v
), bool(v
.bits
))
630 self
.assertEqual(str(Fixed
.from_bits(0x1234, 0, 16, False)),
632 self
.assertEqual(str(Fixed
.from_bits(-0x1234, 0, 16, True)),
634 self
.assertEqual(str(Fixed
.from_bits(0x12345, 3, 20, True)),
636 self
.assertEqual(str(Fixed(123.625, 3, 12, True)),
639 self
.assertEqual(str(Fixed
.from_bits(0x1, 0, 20, True)),
641 self
.assertEqual(str(Fixed
.from_bits(0x2, 1, 20, True)),
643 self
.assertEqual(str(Fixed
.from_bits(0x4, 2, 20, True)),
645 self
.assertEqual(str(Fixed
.from_bits(0x9, 3, 20, True)),
647 self
.assertEqual(str(Fixed
.from_bits(0x12, 4, 20, True)),
649 self
.assertEqual(str(Fixed
.from_bits(0x24, 5, 20, True)),
651 self
.assertEqual(str(Fixed
.from_bits(0x48, 6, 20, True)),
653 self
.assertEqual(str(Fixed
.from_bits(0x91, 7, 20, True)),
655 self
.assertEqual(str(Fixed
.from_bits(0x123, 8, 20, True)),
657 self
.assertEqual(str(Fixed
.from_bits(0x246, 9, 20, True)),
659 self
.assertEqual(str(Fixed
.from_bits(0x48d, 10, 20, True)),
661 self
.assertEqual(str(Fixed
.from_bits(0x91a, 11, 20, True)),
663 self
.assertEqual(str(Fixed
.from_bits(0x1234, 12, 20, True)),
665 self
.assertEqual(str(Fixed
.from_bits(0x2468, 13, 20, True)),
667 self
.assertEqual(str(Fixed
.from_bits(0x48d1, 14, 20, True)),
669 self
.assertEqual(str(Fixed
.from_bits(0x91a2, 15, 20, True)),
671 self
.assertEqual(str(Fixed
.from_bits(0x12345, 16, 20, True)),
673 self
.assertEqual(str(Fixed
.from_bits(0x2468a, 17, 20, True)),
675 self
.assertEqual(str(Fixed
.from_bits(0x48d14, 18, 20, True)),
677 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, True)),
679 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, False)),
683 class TestFixedSqrtFn(unittest
.TestCase
):
684 def test_on_ints(self
):
685 for radicand
in range(-1, 32):
689 root
= math
.floor(math
.sqrt(radicand
))
690 remainder
= radicand
- root
* root
691 expected
= RootRemainder(root
, remainder
)
692 with self
.subTest(radicand
=radicand
, expected
=expected
):
693 self
.assertEqual(repr(fixed_sqrt(radicand
)), repr(expected
))
696 remainder
= radicand
- root
* root
697 expected
= RootRemainder(root
, remainder
)
698 with self
.subTest(radicand
=radicand
, expected
=expected
):
699 self
.assertEqual(repr(fixed_sqrt(radicand
)), repr(expected
))
701 def test_on_fixed(self
):
702 for signed
in False, True:
703 for bit_width
in range(1, 10):
704 for fract_width
in range(bit_width
):
705 for bits
in range(1 << bit_width
):
706 radicand
= Fixed
.from_bits(bits
,
712 root
= radicand
.with_value(math
.sqrt(float(radicand
)))
713 remainder
= radicand
- root
* root
714 expected
= RootRemainder(root
, remainder
)
715 with self
.subTest(radicand
=repr(radicand
),
716 expected
=repr(expected
)):
717 self
.assertEqual(repr(fixed_sqrt(radicand
)),
720 def test_misc_cases(self
):
723 (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
724 (Fixed(2, 30, 32, False),
725 "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)")
727 for radicand
, expected
in test_cases
:
728 with self
.subTest(radicand
=str(radicand
), expected
=expected
):
729 self
.assertEqual(str(fixed_sqrt(radicand
)), expected
)
732 class TestFixedSqrt(unittest
.TestCase
):
733 def helper(self
, log2_radix
):
734 for bit_width
in range(1, 8):
735 for fract_width
in range(bit_width
):
736 for radicand_bits
in range(1 << bit_width
):
737 radicand
= Fixed
.from_bits(radicand_bits
,
741 root_remainder
= fixed_sqrt(radicand
)
742 with self
.subTest(radicand
=repr(radicand
),
743 root_remainder
=repr(root_remainder
),
744 log2_radix
=log2_radix
):
745 obj
= FixedSqrt(radicand
, log2_radix
)
746 for _
in range(250 * bit_width
):
747 self
.assertEqual(obj
.root
* obj
.root
,
749 self
.assertGreaterEqual(obj
.radicand
,
751 if obj
.calculate_stage():
754 self
.fail("infinite loop")
755 self
.assertEqual(obj
.root
* obj
.root
,
757 self
.assertGreaterEqual(obj
.radicand
,
759 self
.assertEqual(obj
.remainder
,
760 obj
.radicand
- obj
.root_squared
)
761 self
.assertEqual(obj
.root
, root_remainder
.root
)
762 self
.assertEqual(obj
.remainder
,
763 root_remainder
.remainder
)
765 def test_radix_2(self
):
768 def test_radix_4(self
):
771 def test_radix_8(self
):
774 def test_radix_16(self
):
778 class TestFixedRSqrtFn(unittest
.TestCase
):
780 for bits
in range(1, 1 << 5):
781 radicand
= Fixed
.from_bits(bits
, 5, 12, False)
782 float_root
= 1 / math
.sqrt(float(radicand
))
783 root
= radicand
.with_value(float_root
)
784 remainder
= 1 - root
* root
* radicand
785 expected
= RootRemainder(root
, remainder
)
786 with self
.subTest(radicand
=repr(radicand
),
787 expected
=repr(expected
)):
788 self
.assertEqual(repr(fixed_rsqrt(radicand
)),
792 for signed
in False, True:
793 for bit_width
in range(1, 10):
794 for fract_width
in range(bit_width
):
795 for bits
in range(1 << bit_width
):
796 radicand
= Fixed
.from_bits(bits
,
802 float_root
= 1 / math
.sqrt(float(radicand
))
803 max_value
= radicand
.with_bits(
804 (1 << (bit_width
- signed
)) - 1)
805 if float_root
> float(max_value
):
808 root
= radicand
.with_value(float_root
)
809 remainder
= 1 - root
* root
* radicand
810 expected
= RootRemainder(root
, remainder
)
811 with self
.subTest(radicand
=repr(radicand
),
812 expected
=repr(expected
)):
813 self
.assertEqual(repr(fixed_rsqrt(radicand
)),
816 def test_misc_cases(self
):
819 (Fixed(0.5, 30, 32, False),
820 "RootRemainder(fixed:0x1.6a09e664, "
821 "fixed:0x0.0000000596d014780000000)")
823 for radicand
, expected
in test_cases
:
824 with self
.subTest(radicand
=str(radicand
), expected
=expected
):
825 self
.assertEqual(str(fixed_rsqrt(radicand
)), expected
)
828 class TestFixedRSqrt(unittest
.TestCase
):
829 def helper(self
, log2_radix
):
830 for bit_width
in range(1, 8):
831 for fract_width
in range(bit_width
):
832 for radicand_bits
in range(1, 1 << bit_width
):
833 radicand
= Fixed
.from_bits(radicand_bits
,
837 root_remainder
= fixed_rsqrt(radicand
)
838 with self
.subTest(radicand
=repr(radicand
),
839 root_remainder
=repr(root_remainder
),
840 log2_radix
=log2_radix
):
841 obj
= FixedRSqrt(radicand
, log2_radix
)
842 for _
in range(250 * bit_width
):
843 self
.assertEqual(obj
.radicand
* obj
.root
,
845 self
.assertEqual(obj
.radicand_root
* obj
.root
,
846 obj
.radicand_root_squared
)
847 self
.assertGreaterEqual(1,
848 obj
.radicand_root_squared
)
849 if obj
.calculate_stage():
852 self
.fail("infinite loop")
853 self
.assertEqual(obj
.radicand
* obj
.root
,
855 self
.assertEqual(obj
.radicand_root
* obj
.root
,
856 obj
.radicand_root_squared
)
857 self
.assertGreaterEqual(1,
858 obj
.radicand_root_squared
)
859 self
.assertEqual(obj
.remainder
,
860 1 - obj
.radicand_root_squared
)
861 self
.assertEqual(obj
.root
, root_remainder
.root
)
862 self
.assertEqual(obj
.remainder
,
863 root_remainder
.remainder
)
865 def test_radix_2(self
):
868 def test_radix_4(self
):
871 def test_radix_8(self
):
874 def test_radix_16(self
):
878 class TestFixedUDivRemSqrtRSqrt(unittest
.TestCase
):
880 def show_fixed(bits
, fract_width
, bit_width
):
881 fixed
= Fixed
.from_bits(bits
, fract_width
, bit_width
, False)
882 return f
"{str(fixed)}:{repr(fixed)}"
884 def check_invariants(self
,
892 self
.assertEqual(obj
.dividend
, dividend
)
893 self
.assertEqual(obj
.divisor_radicand
, divisor_radicand
)
894 self
.assertEqual(obj
.operation
, operation
)
895 self
.assertEqual(obj
.bit_width
, bit_width
)
896 self
.assertEqual(obj
.fract_width
, fract_width
)
897 self
.assertEqual(obj
.log2_radix
, log2_radix
)
898 self
.assertEqual(obj
.root_times_radicand
,
899 obj
.quotient_root
* obj
.divisor_radicand
)
900 self
.assertGreaterEqual(obj
.compare_lhs
, obj
.compare_rhs
)
901 self
.assertEqual(obj
.remainder
, obj
.compare_lhs
- obj
.compare_rhs
)
902 if operation
is Operation
.UDivRem
:
903 self
.assertEqual(obj
.compare_lhs
, obj
.dividend
<< fract_width
)
904 self
.assertEqual(obj
.compare_rhs
,
905 (obj
.quotient_root
* obj
.divisor_radicand
)
907 elif operation
is Operation
.SqrtRem
:
908 self
.assertEqual(obj
.compare_lhs
,
909 obj
.divisor_radicand
<< (fract_width
* 2))
910 self
.assertEqual(obj
.compare_rhs
,
911 (obj
.quotient_root
* obj
.quotient_root
)
914 assert operation
is Operation
.RSqrtRem
915 self
.assertEqual(obj
.compare_lhs
,
916 1 << (fract_width
* 3))
917 self
.assertEqual(obj
.compare_rhs
,
918 obj
.quotient_root
* obj
.quotient_root
919 * obj
.divisor_radicand
)
921 def handle_case(self
,
928 dividend_str
= self
.show_fixed(dividend
,
930 bit_width
+ fract_width
)
931 divisor_radicand_str
= self
.show_fixed(divisor_radicand
,
934 with self
.subTest(dividend
=dividend_str
,
935 divisor_radicand
=divisor_radicand_str
,
936 operation
=operation
.name
,
938 fract_width
=fract_width
,
939 log2_radix
=log2_radix
):
940 if operation
is Operation
.UDivRem
:
941 if divisor_radicand
== 0:
943 quotient_root
, remainder
= div_rem(dividend
,
947 remainder
<<= fract_width
948 elif operation
is Operation
.SqrtRem
:
949 root_remainder
= fixed_sqrt(Fixed
.from_bits(divisor_radicand
,
953 self
.assertEqual(root_remainder
.root
.bit_width
,
955 self
.assertEqual(root_remainder
.root
.fract_width
,
957 self
.assertEqual(root_remainder
.remainder
.bit_width
,
959 self
.assertEqual(root_remainder
.remainder
.fract_width
,
961 quotient_root
= root_remainder
.root
.bits
962 remainder
= root_remainder
.remainder
.bits
<< fract_width
964 assert operation
is Operation
.RSqrtRem
965 if divisor_radicand
== 0:
967 root_remainder
= fixed_rsqrt(Fixed
.from_bits(divisor_radicand
,
971 self
.assertEqual(root_remainder
.root
.bit_width
,
973 self
.assertEqual(root_remainder
.root
.fract_width
,
975 self
.assertEqual(root_remainder
.remainder
.bit_width
,
977 self
.assertEqual(root_remainder
.remainder
.fract_width
,
979 quotient_root
= root_remainder
.root
.bits
980 remainder
= root_remainder
.remainder
.bits
981 if quotient_root
>= (1 << bit_width
):
983 quotient_root_str
= self
.show_fixed(quotient_root
,
986 remainder_str
= self
.show_fixed(remainder
,
989 with self
.subTest(quotient_root
=quotient_root_str
,
990 remainder
=remainder_str
):
991 obj
= FixedUDivRemSqrtRSqrt(dividend
,
997 for _
in range(250 * bit_width
):
998 self
.check_invariants(dividend
,
1005 if obj
.calculate_stage():
1008 self
.fail("infinite loop")
1009 self
.check_invariants(dividend
,
1016 self
.assertEqual(obj
.quotient_root
, quotient_root
)
1017 self
.assertEqual(obj
.remainder
, remainder
)
1019 def helper(self
, log2_radix
, operation
):
1020 bit_width_range
= range(1, 8)
1021 if operation
is Operation
.UDivRem
:
1022 bit_width_range
= range(1, 6)
1023 for bit_width
in bit_width_range
:
1024 for fract_width
in range(bit_width
):
1025 for divisor_radicand
in range(1 << bit_width
):
1026 dividend_range
= range(1)
1027 if operation
is Operation
.UDivRem
:
1028 dividend_range
= range(1 << (bit_width
+ fract_width
))
1029 for dividend
in dividend_range
:
1030 self
.handle_case(dividend
,
1037 def test_radix_2_UDiv(self
):
1038 self
.helper(1, Operation
.UDivRem
)
1040 def test_radix_4_UDiv(self
):
1041 self
.helper(2, Operation
.UDivRem
)
1043 def test_radix_8_UDiv(self
):
1044 self
.helper(3, Operation
.UDivRem
)
1046 def test_radix_16_UDiv(self
):
1047 self
.helper(4, Operation
.UDivRem
)
1049 def test_radix_2_Sqrt(self
):
1050 self
.helper(1, Operation
.SqrtRem
)
1052 def test_radix_4_Sqrt(self
):
1053 self
.helper(2, Operation
.SqrtRem
)
1055 def test_radix_8_Sqrt(self
):
1056 self
.helper(3, Operation
.SqrtRem
)
1058 def test_radix_16_Sqrt(self
):
1059 self
.helper(4, Operation
.SqrtRem
)
1061 def test_radix_2_RSqrt(self
):
1062 self
.helper(1, Operation
.RSqrtRem
)
1064 def test_radix_4_RSqrt(self
):
1065 self
.helper(2, Operation
.RSqrtRem
)
1067 def test_radix_8_RSqrt(self
):
1068 self
.helper(3, Operation
.RSqrtRem
)
1070 def test_radix_16_RSqrt(self
):
1071 self
.helper(4, Operation
.RSqrtRem
)