From f0de69ba86859285336ae50167c7e4b4c43c4024 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 1 Jul 2019 00:01:32 -0700 Subject: [PATCH] added tests for rest of Fixed --- src/ieee754/div_rem_sqrt_rsqrt/algorithm.py | 6 +- .../div_rem_sqrt_rsqrt/test_algorithm.py | 162 +++++++++++++++--- 2 files changed, 146 insertions(+), 22 deletions(-) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py index 7ec21510..199450ed 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py @@ -333,6 +333,10 @@ class Fixed: bits = self.bits * rhs_bits return self.from_bits(bits, fract_width, bit_width, self.signed) + def __rmul__(self, rhs): + """ Reverse Multiplication. """ + return self.__mul__(rhs) + @staticmethod def _cmp_impl(lhs, rhs, fract_width, bit_width, signed): if lhs < rhs: @@ -374,7 +378,7 @@ class Fixed: """ Greater Than or Equal.""" return self.cmp(rhs) >= 0 - def __bool__(self, rhs): + def __bool__(self): """ Convert to bool.""" return bool(self.bits) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py index b1a944d9..a72f9243 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py @@ -482,27 +482,147 @@ class TestFixed(unittest.TestCase): with self.subTest(value=repr(value)): self.assertEqual(float(~value), (~i) / 4) - # TODO: add test for _binary_op - # TODO: add test for __add__ - # TODO: add test for __radd__ - # TODO: add test for __sub__ - # TODO: add test for __rsub__ - # TODO: add test for __and__ - # TODO: add test for __rand__ - # TODO: add test for __or__ - # TODO: add test for __ror__ - # TODO: add test for __xor__ - # TODO: add test for __rxor__ - # TODO: add test for __mul__ - # TODO: add test for _cmp_impl - # TODO: add test for cmp - # TODO: add test for __lt__ - # TODO: add test for __le__ - # TODO: add test for __eq__ - # TODO: add test for __ne__ - # TODO: add test for __gt__ - # TODO: add test for __ge__ - # TODO: add test for __bool__ + @staticmethod + def get_test_values(max_bit_width, include_int): + for signed in False, True: + if include_int: + for bits in range(1 << max_bit_width): + int_value = Const.normalize(bits, (max_bit_width, signed)) + yield int_value + for bit_width in range(1, max_bit_width): + for fract_width in range(bit_width + 1): + for bits in range(1 << bit_width): + yield Fixed.from_bits(bits, + fract_width, + bit_width, + signed) + + def binary_op_test_helper(self, + operation, + is_fixed=True, + width_combine_op=max, + adjust_bits_op=None): + def default_adjust_bits_op(bits, out_fract_width, in_fract_width): + return bits << (out_fract_width - in_fract_width) + if adjust_bits_op is None: + adjust_bits_op = default_adjust_bits_op + max_bit_width = 5 + for lhs in self.get_test_values(max_bit_width, True): + lhs_is_int = isinstance(lhs, int) + for rhs in self.get_test_values(max_bit_width, not lhs_is_int): + rhs_is_int = isinstance(rhs, int) + if lhs_is_int: + assert not rhs_is_int + lhs_int = adjust_bits_op(lhs, rhs.fract_width, 0) + int_result = operation(lhs_int, rhs.bits) + if is_fixed: + expected = Fixed.from_bits(int_result, + rhs.fract_width, + rhs.bit_width, + rhs.signed) + else: + expected = int_result + elif rhs_is_int: + rhs_int = adjust_bits_op(rhs, lhs.fract_width, 0) + int_result = operation(lhs.bits, rhs_int) + if is_fixed: + expected = Fixed.from_bits(int_result, + lhs.fract_width, + lhs.bit_width, + lhs.signed) + else: + expected = int_result + elif lhs.signed != rhs.signed: + continue + else: + fract_width = width_combine_op(lhs.fract_width, + rhs.fract_width) + int_width = width_combine_op(lhs.bit_width + - lhs.fract_width, + rhs.bit_width + - rhs.fract_width) + bit_width = fract_width + int_width + lhs_int = adjust_bits_op(lhs.bits, + fract_width, + lhs.fract_width) + rhs_int = adjust_bits_op(rhs.bits, + fract_width, + rhs.fract_width) + int_result = operation(lhs_int, rhs_int) + if is_fixed: + expected = Fixed.from_bits(int_result, + fract_width, + bit_width, + lhs.signed) + else: + expected = int_result + with self.subTest(lhs=repr(lhs), + rhs=repr(rhs), + expected=repr(expected)): + result = operation(lhs, rhs) + if is_fixed: + self.assertEqual(result.bit_width, expected.bit_width) + self.assertEqual(result.signed, expected.signed) + self.assertEqual(result.fract_width, + expected.fract_width) + self.assertEqual(result.bits, expected.bits) + else: + self.assertEqual(result, expected) + + def test_add(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs + rhs) + + def test_sub(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs - rhs) + + def test_and(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs & rhs) + + def test_or(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs | rhs) + + def test_xor(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs ^ rhs) + + def test_mul(self): + def adjust_bits_op(bits, out_fract_width, in_fract_width): + return bits + self.binary_op_test_helper(lambda lhs, rhs: lhs * rhs, + True, + lambda l_width, r_width: l_width + r_width, + adjust_bits_op) + + def test_cmp(self): + def cmp(lhs, rhs): + if lhs < rhs: + return -1 + elif lhs > rhs: + return 1 + return 0 + self.binary_op_test_helper(cmp, False) + + def test_lt(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs < rhs, False) + + def test_le(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs <= rhs, False) + + def test_eq(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs == rhs, False) + + def test_ne(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs != rhs, False) + + def test_gt(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs > rhs, False) + + def test_ge(self): + self.binary_op_test_helper(lambda lhs, rhs: lhs >= rhs, False) + + def test_bool(self): + for v in self.get_test_values(6, False): + with self.subTest(v=repr(v)): + self.assertEqual(bool(v), bool(v.bits)) def test_str(self): self.assertEqual(str(Fixed.from_bits(0x1234, 0, 16, False)), -- 2.30.2