1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
4 from nmigen
.hdl
.ast
import Const
5 from .algorithm
import (div_rem
, UnsignedDivRem
, DivRem
,
6 Fixed
, RootRemainder
, fixed_sqrt
, FixedSqrt
,
7 fixed_rsqrt
, FixedRSqrt
)
12 class TestDivRemFn(unittest
.TestCase
):
13 def test_signed(self
):
15 # numerator, denominator, quotient, remainder
128 (-8, -1, -8, 0), # overflows and wraps around
273 for (n
, d
, q
, r
) in test_cases
:
274 self
.assertEqual(div_rem(n
, d
, 4, True), (q
, r
))
276 def test_unsigned(self
):
283 # div_rem matches // and % for unsigned integers
286 self
.assertEqual(div_rem(n
, d
, 4, False), (q
, r
))
289 class TestUnsignedDivRem(unittest
.TestCase
):
290 def helper(self
, log2_radix
):
292 for n
in range(1 << bit_width
):
293 for d
in range(1 << bit_width
):
294 q
, r
= div_rem(n
, d
, bit_width
, False)
295 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
):
296 udr
= UnsignedDivRem(n
, d
, bit_width
, log2_radix
)
297 for _
in range(250 * bit_width
):
298 self
.assertEqual(n
, udr
.quotient
* udr
.divisor
300 if udr
.calculate_stage():
303 self
.fail("infinite loop")
304 self
.assertEqual(n
, udr
.quotient
* udr
.divisor
306 self
.assertEqual(udr
.quotient
, q
)
307 self
.assertEqual(udr
.remainder
, r
)
309 def test_radix_2(self
):
312 def test_radix_4(self
):
315 def test_radix_8(self
):
318 def test_radix_16(self
):
322 class TestDivRem(unittest
.TestCase
):
323 def helper(self
, log2_radix
):
325 for n
in range(1 << bit_width
):
326 for d
in range(1 << bit_width
):
327 for signed
in False, True:
328 n
= Const
.normalize(n
, (bit_width
, signed
))
329 d
= Const
.normalize(d
, (bit_width
, signed
))
330 q
, r
= div_rem(n
, d
, bit_width
, signed
)
331 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
, signed
=signed
):
332 dr
= DivRem(n
, d
, bit_width
, signed
, log2_radix
)
333 for _
in range(250 * bit_width
):
334 if dr
.calculate_stage():
337 self
.fail("infinite loop")
338 self
.assertEqual(dr
.quotient
, q
)
339 self
.assertEqual(dr
.remainder
, r
)
341 def test_radix_2(self
):
344 def test_radix_4(self
):
347 def test_radix_8(self
):
350 def test_radix_16(self
):
354 class TestFixed(unittest
.TestCase
):
def test_constructor(self):
    """Check the raw fields reported by values built with the Fixed constructor.

    Each case pairs a freshly constructed ``Fixed`` with the
    ``(bits, fract_width, bit_width, signed)`` it is expected to report.
    The cases cover int and float inputs as well as re-wrapping an
    existing ``Fixed`` into a different format.
    """
    def check(value, bits, fract_width, bit_width, signed):
        # One assertion per field, in a fixed order, so a failure names
        # exactly which field disagreed.
        self.assertEqual(value.bits, bits)
        self.assertEqual(value.fract_width, fract_width)
        self.assertEqual(value.bit_width, bit_width)
        self.assertEqual(value.signed, signed)

    check(Fixed(0, 0, 1, False), 0, 0, 1, False)
    # 1.0 with 2 fractional bits is the pattern 4, which a 3-bit signed
    # field cannot hold; the expected bits show it wrapping to -4.
    check(Fixed(1, 2, 3, True), -4, 2, 3, True)
    # With one more bit the same value fits without wrapping.
    check(Fixed(1, 2, 4, True), 4, 2, 4, True)
    check(Fixed(1.25, 4, 8, True), 0x14, 4, 8, True)
    check(Fixed(Fixed(2, 0, 12, False), 4, 8, True), 0x20, 4, 8, True)
    wide = Fixed(0x2FF / 2 ** 8, 8, 12, False)
    check(wide, 0x2FF, 8, 12, False)
    # Re-wrapping with 4 instead of 8 fractional bits: 0x2FF -> 0x2F.
    check(Fixed(wide, 4, 8, True), 0x2F, 4, 8, True)
392 def helper_test_from_bits(self
, bit_width
, fract_width
):
394 for bits
in range(1 << bit_width
):
395 with self
.subTest(bit_width
=bit_width
,
396 fract_width
=fract_width
,
399 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
400 self
.assertEqual(value
.bit_width
, bit_width
)
401 self
.assertEqual(value
.fract_width
, fract_width
)
402 self
.assertEqual(value
.signed
, signed
)
403 self
.assertEqual(value
.bits
, bits
)
405 for bits
in range(-1 << (bit_width
- 1), 1 << (bit_width
- 1)):
406 with self
.subTest(bit_width
=bit_width
,
407 fract_width
=fract_width
,
410 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
411 self
.assertEqual(value
.bit_width
, bit_width
)
412 self
.assertEqual(value
.fract_width
, fract_width
)
413 self
.assertEqual(value
.signed
, signed
)
414 self
.assertEqual(value
.bits
, bits
)
def test_from_bits(self):
    """Check that Fixed.from_bits reports back the fields it was given.

    Sweeps bit widths 1 through 4 and every fractional width from 0 up to
    and including the full bit width. ``fract_width == bit_width``
    (purely fractional values) is included on purpose: ``get_test_values``
    already builds such values for the arithmetic tests, so from_bits
    should be checked there too.
    """
    for bit_width in range(1, 5):
        # + 1 so the all-fractional format is covered as well.
        for fract_width in range(bit_width + 1):
            self.helper_test_from_bits(bit_width, fract_width)
422 self
.assertEqual(repr(Fixed
.from_bits(1, 2, 3, False)),
423 "Fixed.from_bits(1, 2, 3, False)")
424 self
.assertEqual(repr(Fixed
.from_bits(-4, 2, 3, True)),
425 "Fixed.from_bits(-4, 2, 3, True)")
426 self
.assertEqual(repr(Fixed
.from_bits(-4, 7, 10, True)),
427 "Fixed.from_bits(-4, 7, 10, True)")
def test_trunc(self):
    """Check that math.trunc on a Fixed matches truncating the equivalent float.

    Sweeps all 16 bit patterns of a signed 4-bit value with 2 fractional
    bits, i.e. the reals -2.0 through 1.75 in steps of 0.25.
    """
    for raw_bits in range(-8, 8):
        fixed = Fixed.from_bits(raw_bits, 2, 4, True)
        real = raw_bits / 4  # 2 fractional bits -> scale by 2 ** -2
        with self.subTest(value=repr(fixed)):
            self.assertEqual(math.trunc(fixed), math.trunc(real))
436 for i
in range(-8, 8):
437 value
= Fixed
.from_bits(i
, 2, 4, True)
438 with self
.subTest(value
=repr(value
)):
439 self
.assertEqual(int(value
), math
.trunc(value
))
def test_float(self):
    """Check that float() of a Fixed equals its bits scaled by 2 ** -fract_width.

    Uses every signed 4-bit pattern with 2 fractional bits. Each value is
    a multiple of 0.25 in [-2.0, 1.75], which binary floating point
    represents exactly, so exact equality is safe.
    """
    for raw_bits in range(-8, 8):
        fixed = Fixed.from_bits(raw_bits, 2, 4, True)
        with self.subTest(value=repr(fixed)):
            self.assertEqual(float(fixed), raw_bits / 4)
def test_floor(self):
    """Check that math.floor on a Fixed matches flooring the equivalent float.

    Negative non-integers are the interesting cases: floor rounds them
    toward negative infinity, unlike trunc.
    """
    for raw_bits in range(-8, 8):
        fixed = Fixed.from_bits(raw_bits, 2, 4, True)
        expected = math.floor(raw_bits / 4)
        with self.subTest(value=repr(fixed)):
            self.assertEqual(math.floor(fixed), expected)
454 for i
in range(-8, 8):
455 value
= Fixed
.from_bits(i
, 2, 4, True)
456 with self
.subTest(value
=repr(value
)):
457 self
.assertEqual(math
.ceil(value
), math
.ceil(i
/ 4))
460 for i
in range(-8, 8):
461 value
= Fixed
.from_bits(i
, 2, 4, True)
462 expected
= -i
/ 4 if i
!= -8 else -2.0 # handle wrap-around
463 with self
.subTest(value
=repr(value
)):
464 self
.assertEqual(float(-value
), expected
)
467 for i
in range(-8, 8):
468 value
= Fixed
.from_bits(i
, 2, 4, True)
469 with self
.subTest(value
=repr(value
)):
471 self
.assertEqual(value
.bits
, i
)
474 for i
in range(-8, 8):
475 value
= Fixed
.from_bits(i
, 2, 4, True)
476 expected
= abs(i
) / 4 if i
!= -8 else -2.0 # handle wrap-around
477 with self
.subTest(value
=repr(value
)):
478 self
.assertEqual(float(abs(value
)), expected
)
481 for i
in range(-8, 8):
482 value
= Fixed
.from_bits(i
, 2, 4, True)
483 with self
.subTest(value
=repr(value
)):
484 self
.assertEqual(float(~value
), (~i
) / 4)
487 def get_test_values(max_bit_width
, include_int
):
488 for signed
in False, True:
490 for bits
in range(1 << max_bit_width
):
491 int_value
= Const
.normalize(bits
, (max_bit_width
, signed
))
493 for bit_width
in range(1, max_bit_width
):
494 for fract_width
in range(bit_width
+ 1):
495 for bits
in range(1 << bit_width
):
496 yield Fixed
.from_bits(bits
,
501 def binary_op_test_helper(self
,
504 width_combine_op
=max,
505 adjust_bits_op
=None):
506 def default_adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
507 return bits
<< (out_fract_width
- in_fract_width
)
508 if adjust_bits_op
is None:
509 adjust_bits_op
= default_adjust_bits_op
511 for lhs
in self
.get_test_values(max_bit_width
, True):
512 lhs_is_int
= isinstance(lhs
, int)
513 for rhs
in self
.get_test_values(max_bit_width
, not lhs_is_int
):
514 rhs_is_int
= isinstance(rhs
, int)
516 assert not rhs_is_int
517 lhs_int
= adjust_bits_op(lhs
, rhs
.fract_width
, 0)
518 int_result
= operation(lhs_int
, rhs
.bits
)
520 expected
= Fixed
.from_bits(int_result
,
525 expected
= int_result
527 rhs_int
= adjust_bits_op(rhs
, lhs
.fract_width
, 0)
528 int_result
= operation(lhs
.bits
, rhs_int
)
530 expected
= Fixed
.from_bits(int_result
,
535 expected
= int_result
536 elif lhs
.signed
!= rhs
.signed
:
539 fract_width
= width_combine_op(lhs
.fract_width
,
541 int_width
= width_combine_op(lhs
.bit_width
545 bit_width
= fract_width
+ int_width
546 lhs_int
= adjust_bits_op(lhs
.bits
,
549 rhs_int
= adjust_bits_op(rhs
.bits
,
552 int_result
= operation(lhs_int
, rhs_int
)
554 expected
= Fixed
.from_bits(int_result
,
559 expected
= int_result
560 with self
.subTest(lhs
=repr(lhs
),
562 expected
=repr(expected
)):
563 result
= operation(lhs
, rhs
)
565 self
.assertEqual(result
.bit_width
, expected
.bit_width
)
566 self
.assertEqual(result
.signed
, expected
.signed
)
567 self
.assertEqual(result
.fract_width
,
568 expected
.fract_width
)
569 self
.assertEqual(result
.bits
, expected
.bits
)
571 self
.assertEqual(result
, expected
)
574 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
+ rhs
)
577 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
- rhs
)
580 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
& rhs
)
583 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs | rhs
)
586 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs ^ rhs
)
589 def adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
591 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
* rhs
,
593 lambda l_width
, r_width
: l_width
+ r_width
,
603 self
.binary_op_test_helper(cmp, False)
606 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
< rhs
, False)
609 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
<= rhs
, False)
612 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
== rhs
, False)
615 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
!= rhs
, False)
618 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
> rhs
, False)
621 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
>= rhs
, False)
624 for v
in self
.get_test_values(6, False):
625 with self
.subTest(v
=repr(v
)):
626 self
.assertEqual(bool(v
), bool(v
.bits
))
629 self
.assertEqual(str(Fixed
.from_bits(0x1234, 0, 16, False)),
631 self
.assertEqual(str(Fixed
.from_bits(-0x1234, 0, 16, True)),
633 self
.assertEqual(str(Fixed
.from_bits(0x12345, 3, 20, True)),
635 self
.assertEqual(str(Fixed(123.625, 3, 12, True)),
638 self
.assertEqual(str(Fixed
.from_bits(0x1, 0, 20, True)),
640 self
.assertEqual(str(Fixed
.from_bits(0x2, 1, 20, True)),
642 self
.assertEqual(str(Fixed
.from_bits(0x4, 2, 20, True)),
644 self
.assertEqual(str(Fixed
.from_bits(0x9, 3, 20, True)),
646 self
.assertEqual(str(Fixed
.from_bits(0x12, 4, 20, True)),
648 self
.assertEqual(str(Fixed
.from_bits(0x24, 5, 20, True)),
650 self
.assertEqual(str(Fixed
.from_bits(0x48, 6, 20, True)),
652 self
.assertEqual(str(Fixed
.from_bits(0x91, 7, 20, True)),
654 self
.assertEqual(str(Fixed
.from_bits(0x123, 8, 20, True)),
656 self
.assertEqual(str(Fixed
.from_bits(0x246, 9, 20, True)),
658 self
.assertEqual(str(Fixed
.from_bits(0x48d, 10, 20, True)),
660 self
.assertEqual(str(Fixed
.from_bits(0x91a, 11, 20, True)),
662 self
.assertEqual(str(Fixed
.from_bits(0x1234, 12, 20, True)),
664 self
.assertEqual(str(Fixed
.from_bits(0x2468, 13, 20, True)),
666 self
.assertEqual(str(Fixed
.from_bits(0x48d1, 14, 20, True)),
668 self
.assertEqual(str(Fixed
.from_bits(0x91a2, 15, 20, True)),
670 self
.assertEqual(str(Fixed
.from_bits(0x12345, 16, 20, True)),
672 self
.assertEqual(str(Fixed
.from_bits(0x2468a, 17, 20, True)),
674 self
.assertEqual(str(Fixed
.from_bits(0x48d14, 18, 20, True)),
676 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, True)),
678 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, False)),
682 class TestFixedSqrtFn(unittest
.TestCase
):
683 def test_on_ints(self
):
684 for radicand
in range(-1, 32):
688 root
= math
.floor(math
.sqrt(radicand
))
689 remainder
= radicand
- root
* root
690 expected
= RootRemainder(root
, remainder
)
691 with self
.subTest(radicand
=radicand
, expected
=expected
):
692 self
.assertEqual(repr(fixed_sqrt(radicand
)), repr(expected
))
695 remainder
= radicand
- root
* root
696 expected
= RootRemainder(root
, remainder
)
697 with self
.subTest(radicand
=radicand
, expected
=expected
):
698 self
.assertEqual(repr(fixed_sqrt(radicand
)), repr(expected
))
700 def test_on_fixed(self
):
701 for signed
in False, True:
702 for bit_width
in range(1, 10):
703 for fract_width
in range(bit_width
):
704 for bits
in range(1 << bit_width
):
705 radicand
= Fixed
.from_bits(bits
,
711 root
= radicand
.with_value(math
.sqrt(float(radicand
)))
712 remainder
= radicand
- root
* root
713 expected
= RootRemainder(root
, remainder
)
714 with self
.subTest(radicand
=repr(radicand
),
715 expected
=repr(expected
)):
716 self
.assertEqual(repr(fixed_sqrt(radicand
)),
719 def test_misc_cases(self
):
722 (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
723 (Fixed(2, 30, 32, False),
724 "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)")
726 for radicand
, expected
in test_cases
:
727 with self
.subTest(radicand
=str(radicand
), expected
=expected
):
728 self
.assertEqual(str(fixed_sqrt(radicand
)), expected
)
731 class TestFixedSqrt(unittest
.TestCase
):
732 def helper(self
, log2_radix
):
733 for bit_width
in range(1, 8):
734 for fract_width
in range(bit_width
):
735 for radicand_bits
in range(1 << bit_width
):
736 radicand
= Fixed
.from_bits(radicand_bits
,
740 root_remainder
= fixed_sqrt(radicand
)
741 with self
.subTest(radicand
=repr(radicand
),
742 root_remainder
=repr(root_remainder
),
743 log2_radix
=log2_radix
):
744 obj
= FixedSqrt(radicand
, log2_radix
)
745 for _
in range(250 * bit_width
):
746 self
.assertEqual(obj
.root
* obj
.root
,
748 self
.assertGreaterEqual(obj
.radicand
,
750 if obj
.calculate_stage():
753 self
.fail("infinite loop")
754 self
.assertEqual(obj
.root
* obj
.root
,
756 self
.assertGreaterEqual(obj
.radicand
,
758 self
.assertEqual(obj
.remainder
,
759 obj
.radicand
- obj
.root_squared
)
760 self
.assertEqual(obj
.root
, root_remainder
.root
)
761 self
.assertEqual(obj
.remainder
,
762 root_remainder
.remainder
)
764 def test_radix_2(self
):
767 def test_radix_4(self
):
770 def test_radix_8(self
):
773 def test_radix_16(self
):
777 class TestFixedRSqrtFn(unittest
.TestCase
):
779 for bits
in range(1, 1 << 5):
780 radicand
= Fixed
.from_bits(bits
, 5, 12, False)
781 float_root
= 1 / math
.sqrt(float(radicand
))
782 root
= radicand
.with_value(float_root
)
783 remainder
= 1 - root
* root
* radicand
784 expected
= RootRemainder(root
, remainder
)
785 with self
.subTest(radicand
=repr(radicand
),
786 expected
=repr(expected
)):
787 self
.assertEqual(repr(fixed_rsqrt(radicand
)),
791 for signed
in False, True:
792 for bit_width
in range(1, 10):
793 for fract_width
in range(bit_width
):
794 for bits
in range(1 << bit_width
):
795 radicand
= Fixed
.from_bits(bits
,
801 float_root
= 1 / math
.sqrt(float(radicand
))
802 max_value
= radicand
.with_bits(
803 (1 << (bit_width
- signed
)) - 1)
804 if float_root
> float(max_value
):
807 root
= radicand
.with_value(float_root
)
808 remainder
= 1 - root
* root
* radicand
809 expected
= RootRemainder(root
, remainder
)
810 with self
.subTest(radicand
=repr(radicand
),
811 expected
=repr(expected
)):
812 self
.assertEqual(repr(fixed_rsqrt(radicand
)),
815 def test_misc_cases(self
):
818 (Fixed(0.5, 30, 32, False),
819 "RootRemainder(fixed:0x1.6a09e664, "
820 "fixed:0x0.0000000596d014780000000)")
822 for radicand
, expected
in test_cases
:
823 with self
.subTest(radicand
=str(radicand
), expected
=expected
):
824 self
.assertEqual(str(fixed_rsqrt(radicand
)), expected
)
827 class TestFixedRSqrt(unittest
.TestCase
):
828 def helper(self
, log2_radix
):
829 for bit_width
in range(1, 8):
830 for fract_width
in range(bit_width
):
831 for radicand_bits
in range(1, 1 << bit_width
):
832 radicand
= Fixed
.from_bits(radicand_bits
,
836 root_remainder
= fixed_rsqrt(radicand
)
837 with self
.subTest(radicand
=repr(radicand
),
838 root_remainder
=repr(root_remainder
),
839 log2_radix
=log2_radix
):
840 obj
= FixedRSqrt(radicand
, log2_radix
)
841 for _
in range(250 * bit_width
):
842 self
.assertEqual(obj
.radicand
* obj
.root
,
844 self
.assertEqual(obj
.radicand_root
* obj
.root
,
845 obj
.radicand_root_squared
)
846 self
.assertGreaterEqual(1,
847 obj
.radicand_root_squared
)
848 if obj
.calculate_stage():
851 self
.fail("infinite loop")
852 self
.assertEqual(obj
.radicand
* obj
.root
,
854 self
.assertEqual(obj
.radicand_root
* obj
.root
,
855 obj
.radicand_root_squared
)
856 self
.assertGreaterEqual(1,
857 obj
.radicand_root_squared
)
858 self
.assertEqual(obj
.remainder
,
859 1 - obj
.radicand_root_squared
)
860 self
.assertEqual(obj
.root
, root_remainder
.root
)
861 self
.assertEqual(obj
.remainder
,
862 root_remainder
.remainder
)
864 def test_radix_2(self
):
867 def test_radix_4(self
):
870 def test_radix_8(self
):
873 def test_radix_16(self
):