add tests for integer and fractional division
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_algorithm.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 from nmigen.hdl.ast import Const
5 from .algorithm import (div_rem, UnsignedDivRem, DivRem,
6 Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
7 fixed_rsqrt, FixedRSqrt, Operation,
8 FixedUDivRemSqrtRSqrt)
9 import unittest
10 import math
11
12
class TestDivRemFn(unittest.TestCase):
    """Exhaustive 4-bit checks of the reference ``div_rem`` function."""

    def test_signed(self):
        """Compare signed 4-bit ``div_rem`` against a hand-written table.

        The table lists every numerator/denominator pair in ``-8..7``.  As
        recorded in its rows, quotients are truncated toward zero, division
        by zero gives a quotient of ``-1`` with the numerator as remainder,
        and ``-8 / -1`` wraps around to ``-8``.
        """
        test_cases = [
            # numerator, denominator, quotient, remainder
            # denominator = -8
            (-8, -8, 1, 0), (-7, -8, 0, -7), (-6, -8, 0, -6), (-5, -8, 0, -5),
            (-4, -8, 0, -4), (-3, -8, 0, -3), (-2, -8, 0, -2), (-1, -8, 0, -1),
            (0, -8, 0, 0), (1, -8, 0, 1), (2, -8, 0, 2), (3, -8, 0, 3),
            (4, -8, 0, 4), (5, -8, 0, 5), (6, -8, 0, 6), (7, -8, 0, 7),
            # denominator = -7
            (-8, -7, 1, -1), (-7, -7, 1, 0), (-6, -7, 0, -6), (-5, -7, 0, -5),
            (-4, -7, 0, -4), (-3, -7, 0, -3), (-2, -7, 0, -2), (-1, -7, 0, -1),
            (0, -7, 0, 0), (1, -7, 0, 1), (2, -7, 0, 2), (3, -7, 0, 3),
            (4, -7, 0, 4), (5, -7, 0, 5), (6, -7, 0, 6), (7, -7, -1, 0),
            # denominator = -6
            (-8, -6, 1, -2), (-7, -6, 1, -1), (-6, -6, 1, 0), (-5, -6, 0, -5),
            (-4, -6, 0, -4), (-3, -6, 0, -3), (-2, -6, 0, -2), (-1, -6, 0, -1),
            (0, -6, 0, 0), (1, -6, 0, 1), (2, -6, 0, 2), (3, -6, 0, 3),
            (4, -6, 0, 4), (5, -6, 0, 5), (6, -6, -1, 0), (7, -6, -1, 1),
            # denominator = -5
            (-8, -5, 1, -3), (-7, -5, 1, -2), (-6, -5, 1, -1), (-5, -5, 1, 0),
            (-4, -5, 0, -4), (-3, -5, 0, -3), (-2, -5, 0, -2), (-1, -5, 0, -1),
            (0, -5, 0, 0), (1, -5, 0, 1), (2, -5, 0, 2), (3, -5, 0, 3),
            (4, -5, 0, 4), (5, -5, -1, 0), (6, -5, -1, 1), (7, -5, -1, 2),
            # denominator = -4
            (-8, -4, 2, 0), (-7, -4, 1, -3), (-6, -4, 1, -2), (-5, -4, 1, -1),
            (-4, -4, 1, 0), (-3, -4, 0, -3), (-2, -4, 0, -2), (-1, -4, 0, -1),
            (0, -4, 0, 0), (1, -4, 0, 1), (2, -4, 0, 2), (3, -4, 0, 3),
            (4, -4, -1, 0), (5, -4, -1, 1), (6, -4, -1, 2), (7, -4, -1, 3),
            # denominator = -3
            (-8, -3, 2, -2), (-7, -3, 2, -1), (-6, -3, 2, 0), (-5, -3, 1, -2),
            (-4, -3, 1, -1), (-3, -3, 1, 0), (-2, -3, 0, -2), (-1, -3, 0, -1),
            (0, -3, 0, 0), (1, -3, 0, 1), (2, -3, 0, 2), (3, -3, -1, 0),
            (4, -3, -1, 1), (5, -3, -1, 2), (6, -3, -2, 0), (7, -3, -2, 1),
            # denominator = -2
            (-8, -2, 4, 0), (-7, -2, 3, -1), (-6, -2, 3, 0), (-5, -2, 2, -1),
            (-4, -2, 2, 0), (-3, -2, 1, -1), (-2, -2, 1, 0), (-1, -2, 0, -1),
            (0, -2, 0, 0), (1, -2, 0, 1), (2, -2, -1, 0), (3, -2, -1, 1),
            (4, -2, -2, 0), (5, -2, -2, 1), (6, -2, -3, 0), (7, -2, -3, 1),
            # denominator = -1
            (-8, -1, -8, 0),  # overflows and wraps around
            (-7, -1, 7, 0), (-6, -1, 6, 0), (-5, -1, 5, 0),
            (-4, -1, 4, 0), (-3, -1, 3, 0), (-2, -1, 2, 0), (-1, -1, 1, 0),
            (0, -1, 0, 0), (1, -1, -1, 0), (2, -1, -2, 0), (3, -1, -3, 0),
            (4, -1, -4, 0), (5, -1, -5, 0), (6, -1, -6, 0), (7, -1, -7, 0),
            # denominator = 0
            (-8, 0, -1, -8), (-7, 0, -1, -7), (-6, 0, -1, -6), (-5, 0, -1, -5),
            (-4, 0, -1, -4), (-3, 0, -1, -3), (-2, 0, -1, -2), (-1, 0, -1, -1),
            (0, 0, -1, 0), (1, 0, -1, 1), (2, 0, -1, 2), (3, 0, -1, 3),
            (4, 0, -1, 4), (5, 0, -1, 5), (6, 0, -1, 6), (7, 0, -1, 7),
            # denominator = 1
            (-8, 1, -8, 0), (-7, 1, -7, 0), (-6, 1, -6, 0), (-5, 1, -5, 0),
            (-4, 1, -4, 0), (-3, 1, -3, 0), (-2, 1, -2, 0), (-1, 1, -1, 0),
            (0, 1, 0, 0), (1, 1, 1, 0), (2, 1, 2, 0), (3, 1, 3, 0),
            (4, 1, 4, 0), (5, 1, 5, 0), (6, 1, 6, 0), (7, 1, 7, 0),
            # denominator = 2
            (-8, 2, -4, 0), (-7, 2, -3, -1), (-6, 2, -3, 0), (-5, 2, -2, -1),
            (-4, 2, -2, 0), (-3, 2, -1, -1), (-2, 2, -1, 0), (-1, 2, 0, -1),
            (0, 2, 0, 0), (1, 2, 0, 1), (2, 2, 1, 0), (3, 2, 1, 1),
            (4, 2, 2, 0), (5, 2, 2, 1), (6, 2, 3, 0), (7, 2, 3, 1),
            # denominator = 3
            (-8, 3, -2, -2), (-7, 3, -2, -1), (-6, 3, -2, 0), (-5, 3, -1, -2),
            (-4, 3, -1, -1), (-3, 3, -1, 0), (-2, 3, 0, -2), (-1, 3, 0, -1),
            (0, 3, 0, 0), (1, 3, 0, 1), (2, 3, 0, 2), (3, 3, 1, 0),
            (4, 3, 1, 1), (5, 3, 1, 2), (6, 3, 2, 0), (7, 3, 2, 1),
            # denominator = 4
            (-8, 4, -2, 0), (-7, 4, -1, -3), (-6, 4, -1, -2), (-5, 4, -1, -1),
            (-4, 4, -1, 0), (-3, 4, 0, -3), (-2, 4, 0, -2), (-1, 4, 0, -1),
            (0, 4, 0, 0), (1, 4, 0, 1), (2, 4, 0, 2), (3, 4, 0, 3),
            (4, 4, 1, 0), (5, 4, 1, 1), (6, 4, 1, 2), (7, 4, 1, 3),
            # denominator = 5
            (-8, 5, -1, -3), (-7, 5, -1, -2), (-6, 5, -1, -1), (-5, 5, -1, 0),
            (-4, 5, 0, -4), (-3, 5, 0, -3), (-2, 5, 0, -2), (-1, 5, 0, -1),
            (0, 5, 0, 0), (1, 5, 0, 1), (2, 5, 0, 2), (3, 5, 0, 3),
            (4, 5, 0, 4), (5, 5, 1, 0), (6, 5, 1, 1), (7, 5, 1, 2),
            # denominator = 6
            (-8, 6, -1, -2), (-7, 6, -1, -1), (-6, 6, -1, 0), (-5, 6, 0, -5),
            (-4, 6, 0, -4), (-3, 6, 0, -3), (-2, 6, 0, -2), (-1, 6, 0, -1),
            (0, 6, 0, 0), (1, 6, 0, 1), (2, 6, 0, 2), (3, 6, 0, 3),
            (4, 6, 0, 4), (5, 6, 0, 5), (6, 6, 1, 0), (7, 6, 1, 1),
            # denominator = 7
            (-8, 7, -1, -1), (-7, 7, -1, 0), (-6, 7, 0, -6), (-5, 7, 0, -5),
            (-4, 7, 0, -4), (-3, 7, 0, -3), (-2, 7, 0, -2), (-1, 7, 0, -1),
            (0, 7, 0, 0), (1, 7, 0, 1), (2, 7, 0, 2), (3, 7, 0, 3),
            (4, 7, 0, 4), (5, 7, 0, 5), (6, 7, 0, 6), (7, 7, 1, 0),
        ]
        for (n, d, q, r) in test_cases:
            # One subTest per row, so a single wrong result is reported with
            # its inputs and does not hide failures in later rows.
            with self.subTest(n=n, d=d, q=q, r=r):
                self.assertEqual(div_rem(n, d, 4, True), (q, r))

    def test_unsigned(self):
        """Check unsigned 4-bit ``div_rem`` for every input pair.

        Division by zero is expected to give an all-ones quotient (15) and
        the numerator as remainder; every other case is compared against
        Python's ``//`` and ``%``.
        """
        for n in range(16):
            for d in range(16):
                if d == 0:
                    q = 16 - 1
                    r = n
                else:
                    # div_rem matches // and % for unsigned integers
                    q = n // d
                    r = n % d
                with self.subTest(n=n, d=d, q=q, r=r):
                    self.assertEqual(div_rem(n, d, 4, False), (q, r))
288
289
class TestUnsignedDivRem(unittest.TestCase):
    """Step ``UnsignedDivRem`` to completion and compare with ``div_rem``."""

    def _assert_invariants(self, udr, n, d):
        """Assert the properties that must hold before and after each stage."""
        self.assertEqual(udr.dividend, n)
        self.assertEqual(udr.divisor, d)
        self.assertEqual(udr.quotient_times_divisor,
                         udr.quotient * udr.divisor)
        self.assertGreaterEqual(udr.dividend,
                                udr.quotient_times_divisor)

    def helper(self, log2_radix):
        """Run every 4-bit dividend/divisor pair at the given radix.

        :param log2_radix: base-2 logarithm of the radix passed to
            ``UnsignedDivRem``.
        """
        bit_width = 4
        value_count = 1 << bit_width
        stage_limit = 250 * bit_width
        for n in range(value_count):
            for d in range(value_count):
                q, r = div_rem(n, d, bit_width, False)
                with self.subTest(n=n, d=d, q=q, r=r):
                    udr = UnsignedDivRem(n, d, bit_width, log2_radix)
                    finished = False
                    for _ in range(stage_limit):
                        self._assert_invariants(udr, n, d)
                        if udr.calculate_stage():
                            finished = True
                            break
                    if not finished:
                        # calculate_stage() never reported completion
                        self.fail("infinite loop")
                    self._assert_invariants(udr, n, d)
                    self.assertEqual(udr.quotient, q)
                    self.assertEqual(udr.remainder, r)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
329
330
class TestDivRem(unittest.TestCase):
    """Step ``DivRem`` to completion and compare with ``div_rem``."""

    def helper(self, log2_radix):
        """Run every 4-bit bit pattern pair, unsigned and signed.

        :param log2_radix: base-2 logarithm of the radix passed to
            ``DivRem``.
        """
        bit_width = 4
        for n_bits in range(1 << bit_width):
            for d_bits in range(1 << bit_width):
                for signed in False, True:
                    # Normalize from the raw bit patterns each time instead
                    # of rebinding the loop variables, so a signed value
                    # from one iteration never feeds the next one.
                    shape = (bit_width, signed)
                    n = Const.normalize(n_bits, shape)
                    d = Const.normalize(d_bits, shape)
                    q, r = div_rem(n, d, bit_width, signed)
                    with self.subTest(n=n, d=d, q=q, r=r, signed=signed):
                        dr = DivRem(n, d, bit_width, signed, log2_radix)
                        for _ in range(250 * bit_width):
                            if dr.calculate_stage():
                                break
                        else:
                            # calculate_stage() never reported completion
                            self.fail("infinite loop")
                        self.assertEqual(dr.quotient, q)
                        self.assertEqual(dr.remainder, r)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
361
362
class TestFixed(unittest.TestCase):
    """Tests for the ``Fixed`` fixed-point number type."""

    def test_constructor(self):
        """Check the fields produced by ``Fixed(value, fract, width, signed)``."""
        def check(value, bits, fract_width, bit_width, signed):
            self.assertEqual(value.bits, bits)
            self.assertEqual(value.fract_width, fract_width)
            self.assertEqual(value.bit_width, bit_width)
            self.assertEqual(value.signed, signed)

        check(Fixed(0, 0, 1, False), 0, 0, 1, False)
        check(Fixed(1, 2, 3, True), -4, 2, 3, True)
        check(Fixed(1, 2, 4, True), 4, 2, 4, True)
        check(Fixed(1.25, 4, 8, True), 0x14, 4, 8, True)
        check(Fixed(Fixed(2, 0, 12, False), 4, 8, True), 0x20, 4, 8, True)
        wide = Fixed(0x2FF / 2 ** 8, 8, 12, False)
        check(wide, 0x2FF, 8, 12, False)
        # converting to fewer fraction bits drops the low bits
        check(Fixed(wide, 4, 8, True), 0x2F, 4, 8, True)

    def helper_tst_from_bits(self, bit_width, fract_width):
        """Round-trip every unsigned, then every signed, bit pattern."""
        half_range = 1 << (bit_width - 1)
        variants = (
            (False, range(1 << bit_width)),
            (True, range(-half_range, half_range)),
        )
        for signed, all_bits in variants:
            for bits in all_bits:
                with self.subTest(bit_width=bit_width,
                                  fract_width=fract_width,
                                  signed=signed,
                                  bits=hex(bits)):
                    value = Fixed.from_bits(bits, fract_width,
                                            bit_width, signed)
                    self.assertEqual(value.bit_width, bit_width)
                    self.assertEqual(value.fract_width, fract_width)
                    self.assertEqual(value.signed, signed)
                    self.assertEqual(value.bits, bits)

    def test_from_bits(self):
        for bit_width in range(1, 5):
            for fract_width in range(bit_width):
                self.helper_tst_from_bits(bit_width, fract_width)

    def test_repr(self):
        cases = [
            ((1, 2, 3, False), "Fixed.from_bits(1, 2, 3, False)"),
            ((-4, 2, 3, True), "Fixed.from_bits(-4, 2, 3, True)"),
            ((-4, 7, 10, True), "Fixed.from_bits(-4, 7, 10, True)"),
        ]
        for args, text in cases:
            self.assertEqual(repr(Fixed.from_bits(*args)), text)

    @staticmethod
    def _signed_4bit_values():
        """Yield ``(bits, value)`` for each signed 4-bit value, 2 fract bits."""
        for bits in range(-8, 8):
            yield bits, Fixed.from_bits(bits, 2, 4, True)

    def test_trunc(self):
        for bits, value in self._signed_4bit_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(math.trunc(value), math.trunc(bits / 4))

    def test_int(self):
        for _, value in self._signed_4bit_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(int(value), math.trunc(value))

    def test_float(self):
        for bits, value in self._signed_4bit_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(float(value), bits / 4)

    def test_floor(self):
        for bits, value in self._signed_4bit_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(math.floor(value), math.floor(bits / 4))

    def test_ceil(self):
        for bits, value in self._signed_4bit_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(math.ceil(value), math.ceil(bits / 4))

    def test_neg(self):
        for bits, value in self._signed_4bit_values():
            # handle wrap-around: negating the minimum value gives it back
            expected = -2.0 if bits == -8 else -bits / 4
            with self.subTest(value=repr(value)):
                self.assertEqual(float(-value), expected)

    def test_pos(self):
        for bits, value in self._signed_4bit_values():
            with self.subTest(value=repr(value)):
                value = +value
                self.assertEqual(value.bits, bits)

    def test_abs(self):
        for bits, value in self._signed_4bit_values():
            # handle wrap-around: abs of the minimum value stays negative
            expected = -2.0 if bits == -8 else abs(bits) / 4
            with self.subTest(value=repr(value)):
                self.assertEqual(float(abs(value)), expected)

    def test_not(self):
        for bits, value in self._signed_4bit_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(float(~value), (~bits) / 4)

    @staticmethod
    def get_test_values(max_bit_width, include_int):
        """Generate operands for the binary-operator tests.

        For unsigned and then signed: optionally every ``max_bit_width``-bit
        integer (normalized with ``Const.normalize``), followed by every
        ``Fixed`` bit pattern for widths ``1..max_bit_width - 1`` and every
        fraction width from 0 up to the bit width.
        """
        for signed in (False, True):
            if include_int:
                int_shape = (max_bit_width, signed)
                for raw in range(1 << max_bit_width):
                    yield Const.normalize(raw, int_shape)
            for bit_width in range(1, max_bit_width):
                pattern_count = 1 << bit_width
                for fract_width in range(bit_width + 1):
                    for bits in range(pattern_count):
                        yield Fixed.from_bits(bits,
                                              fract_width,
                                              bit_width,
                                              signed)

    def binary_op_test_helper(self,
                              operation,
                              is_fixed=True,
                              width_combine_op=max,
                              adjust_bits_op=None):
        """Check a binary operation on ``Fixed``/``int`` operand pairs.

        The expected result is computed by applying ``operation`` to plain
        integers holding the operands' bits, then compared with
        ``operation`` applied to the operands themselves.

        :param operation: two-argument callable under test.
        :param is_fixed: True if the result is a ``Fixed`` whose fields are
            compared; False if the result is compared directly.
        :param width_combine_op: combines the two operands' fraction widths
            and integer widths into the result's widths.
        :param adjust_bits_op: ``(bits, out_fract_width, in_fract_width)``
            callable aligning an operand's bits; defaults to a left shift by
            the difference of the fraction widths.
        """
        if adjust_bits_op is None:
            def adjust_bits_op(bits, out_fract_width, in_fract_width):
                return bits << (out_fract_width - in_fract_width)

        max_bit_width = 5
        for lhs in self.get_test_values(max_bit_width, True):
            lhs_is_int = isinstance(lhs, int)
            # an int on the left is only paired with Fixed on the right
            for rhs in self.get_test_values(max_bit_width, not lhs_is_int):
                rhs_is_int = isinstance(rhs, int)
                both_fixed = not lhs_is_int and not rhs_is_int
                if both_fixed and lhs.signed != rhs.signed:
                    continue
                if lhs_is_int:
                    assert not rhs_is_int
                    aligned_lhs = adjust_bits_op(lhs, rhs.fract_width, 0)
                    int_result = operation(aligned_lhs, rhs.bits)
                    out_shape = (rhs.fract_width, rhs.bit_width, rhs.signed)
                elif rhs_is_int:
                    aligned_rhs = adjust_bits_op(rhs, lhs.fract_width, 0)
                    int_result = operation(lhs.bits, aligned_rhs)
                    out_shape = (lhs.fract_width, lhs.bit_width, lhs.signed)
                else:
                    out_fract = width_combine_op(lhs.fract_width,
                                                 rhs.fract_width)
                    out_int = width_combine_op(
                        lhs.bit_width - lhs.fract_width,
                        rhs.bit_width - rhs.fract_width)
                    aligned_lhs = adjust_bits_op(lhs.bits,
                                                 out_fract,
                                                 lhs.fract_width)
                    aligned_rhs = adjust_bits_op(rhs.bits,
                                                 out_fract,
                                                 rhs.fract_width)
                    int_result = operation(aligned_lhs, aligned_rhs)
                    out_shape = (out_fract, out_fract + out_int, lhs.signed)
                if is_fixed:
                    expected = Fixed.from_bits(int_result, *out_shape)
                else:
                    expected = int_result
                with self.subTest(lhs=repr(lhs),
                                  rhs=repr(rhs),
                                  expected=repr(expected)):
                    result = operation(lhs, rhs)
                    if not is_fixed:
                        self.assertEqual(result, expected)
                        continue
                    self.assertEqual(result.bit_width, expected.bit_width)
                    self.assertEqual(result.signed, expected.signed)
                    self.assertEqual(result.fract_width,
                                     expected.fract_width)
                    self.assertEqual(result.bits, expected.bits)

    def test_add(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs + rhs)

    def test_sub(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs - rhs)

    def test_and(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs & rhs)

    def test_or(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs | rhs)

    def test_xor(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs ^ rhs)

    def test_mul(self):
        def multiply(lhs, rhs):
            return lhs * rhs

        def add_widths(l_width, r_width):
            # the expected product widths are the sums of the operand widths
            return l_width + r_width

        def keep_bits(bits, out_fract_width, in_fract_width):
            # operands are multiplied without aligning them first
            return bits

        self.binary_op_test_helper(multiply, True, add_widths, keep_bits)

    def test_cmp(self):
        def cmp(lhs, rhs):
            if lhs < rhs:
                return -1
            return 1 if lhs > rhs else 0
        self.binary_op_test_helper(cmp, False)

    def test_lt(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs < rhs, False)

    def test_le(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs <= rhs, False)

    def test_eq(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs == rhs, False)

    def test_ne(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs != rhs, False)

    def test_gt(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs > rhs, False)

    def test_ge(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs >= rhs, False)

    def test_bool(self):
        for v in self.get_test_values(6, False):
            with self.subTest(v=repr(v)):
                self.assertEqual(bool(v), bool(v.bits))

    def test_str(self):
        """Check the ``fixed:0x<int>.<fract>`` text form of ``str()``."""
        def check_from_bits(cases):
            for bits, fract_width, bit_width, signed, text in cases:
                value = Fixed.from_bits(bits, fract_width, bit_width, signed)
                self.assertEqual(str(value), text)

        check_from_bits([
            # bits, fract_width, bit_width, signed, expected text
            (0x1234, 0, 16, False, "fixed:0x1234."),
            (-0x1234, 0, 16, True, "fixed:-0x1234."),
            (0x12345, 3, 20, True, "fixed:0x2468.a"),
        ])
        self.assertEqual(str(Fixed(123.625, 3, 12, True)),
                         "fixed:0x7b.a")

        check_from_bits([
            (0x1, 0, 20, True, "fixed:0x1."),
            (0x2, 1, 20, True, "fixed:0x1.0"),
            (0x4, 2, 20, True, "fixed:0x1.0"),
            (0x9, 3, 20, True, "fixed:0x1.2"),
            (0x12, 4, 20, True, "fixed:0x1.2"),
            (0x24, 5, 20, True, "fixed:0x1.20"),
            (0x48, 6, 20, True, "fixed:0x1.20"),
            (0x91, 7, 20, True, "fixed:0x1.22"),
            (0x123, 8, 20, True, "fixed:0x1.23"),
            (0x246, 9, 20, True, "fixed:0x1.230"),
            (0x48d, 10, 20, True, "fixed:0x1.234"),
            (0x91a, 11, 20, True, "fixed:0x1.234"),
            (0x1234, 12, 20, True, "fixed:0x1.234"),
            (0x2468, 13, 20, True, "fixed:0x1.2340"),
            (0x48d1, 14, 20, True, "fixed:0x1.2344"),
            (0x91a2, 15, 20, True, "fixed:0x1.2344"),
            (0x12345, 16, 20, True, "fixed:0x1.2345"),
            (0x2468a, 17, 20, True, "fixed:0x1.23450"),
            (0x48d14, 18, 20, True, "fixed:0x1.23450"),
            (0x91a28, 19, 20, True, "fixed:-0x0.dcbb0"),
            (0x91a28, 19, 20, False, "fixed:0x1.23450"),
        ])
689
690
class TestFixedSqrtFn(unittest.TestCase):
    """Tests for the reference ``fixed_sqrt`` function."""

    def test_on_ints(self):
        """Check integer radicands, including a negative and a large one."""
        def check(radicand, expected):
            with self.subTest(radicand=radicand, expected=expected):
                self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))

        for radicand in range(-1, 32):
            expected = None  # a negative radicand is expected to give None
            if radicand >= 0:
                root = math.floor(math.sqrt(radicand))
                expected = RootRemainder(root, radicand - root * root)
            check(radicand, expected)

        # a radicand too large for the float-based expectation above
        big_radicand = 2 << 64
        big_root = 0x16A09E667
        check(big_radicand,
              RootRemainder(big_root, big_radicand - big_root * big_root))

    def test_on_fixed(self):
        """Check every non-negative ``Fixed`` radicand up to 9 bits wide."""
        for signed in (False, True):
            for bit_width in range(1, 10):
                for fract_width in range(bit_width):
                    for bits in range(1 << bit_width):
                        radicand = Fixed.from_bits(bits,
                                                   fract_width,
                                                   bit_width,
                                                   signed)
                        if radicand < 0:
                            continue
                        float_root = math.sqrt(float(radicand))
                        root = radicand.with_value(float_root)
                        expected = RootRemainder(root,
                                                 radicand - root * root)
                        with self.subTest(radicand=repr(radicand),
                                          expected=repr(expected)):
                            actual = fixed_sqrt(radicand)
                            self.assertEqual(repr(actual), repr(expected))

    def test_misc_cases(self):
        int_case = (2 << 64,
                    str(RootRemainder(0x16A09E667, 0x2B164C28F)))
        fixed_case = (
            Fixed(2, 30, 32, False),
            "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)")
        # radicand, expected
        for radicand, expected in (int_case, fixed_case):
            with self.subTest(radicand=str(radicand), expected=expected):
                self.assertEqual(str(fixed_sqrt(radicand)), expected)
738
739
class TestFixedSqrt(unittest.TestCase):
    """Step ``FixedSqrt`` to completion and compare with ``fixed_sqrt``."""

    def _assert_invariants(self, obj):
        """Assert the properties that must hold before and after each stage."""
        self.assertEqual(obj.root * obj.root,
                         obj.root_squared)
        self.assertGreaterEqual(obj.radicand,
                                obj.root_squared)

    def helper(self, log2_radix):
        """Run every unsigned radicand up to 7 bits wide at the given radix.

        :param log2_radix: base-2 logarithm of the radix passed to
            ``FixedSqrt``.
        """
        for bit_width in range(1, 8):
            stage_limit = 250 * bit_width
            for fract_width in range(bit_width):
                for radicand_bits in range(1 << bit_width):
                    radicand = Fixed.from_bits(radicand_bits,
                                               fract_width,
                                               bit_width,
                                               False)
                    root_remainder = fixed_sqrt(radicand)
                    with self.subTest(radicand=repr(radicand),
                                      root_remainder=repr(root_remainder),
                                      log2_radix=log2_radix):
                        obj = FixedSqrt(radicand, log2_radix)
                        finished = False
                        for _ in range(stage_limit):
                            self._assert_invariants(obj)
                            if obj.calculate_stage():
                                finished = True
                                break
                        if not finished:
                            # calculate_stage() never reported completion
                            self.fail("infinite loop")
                        self._assert_invariants(obj)
                        self.assertEqual(obj.remainder,
                                         obj.radicand - obj.root_squared)
                        self.assertEqual(obj.root, root_remainder.root)
                        self.assertEqual(obj.remainder,
                                         root_remainder.remainder)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
784
785
class TestFixedRSqrtFn(unittest.TestCase):
    """Tests for the reference ``fixed_rsqrt`` function."""

    def _check(self, radicand, root):
        """Compare ``fixed_rsqrt(radicand)`` with the expected ``root``."""
        expected = RootRemainder(root, 1 - root * root * radicand)
        with self.subTest(radicand=repr(radicand),
                          expected=repr(expected)):
            self.assertEqual(repr(fixed_rsqrt(radicand)),
                             repr(expected))

    def test2(self):
        """Check small non-zero radicands in a 12-bit, 5-fract-bit format."""
        for bits in range(1, 1 << 5):
            radicand = Fixed.from_bits(bits, 5, 12, False)
            float_root = 1 / math.sqrt(float(radicand))
            self._check(radicand, radicand.with_value(float_root))

    def test(self):
        """Check every positive radicand up to 9 bits wide."""
        for signed in (False, True):
            for bit_width in range(1, 10):
                # largest representable bit pattern; one bit fewer if signed
                max_bits = (1 << (bit_width - signed)) - 1
                for fract_width in range(bit_width):
                    for bits in range(1 << bit_width):
                        radicand = Fixed.from_bits(bits,
                                                   fract_width,
                                                   bit_width,
                                                   signed)
                        if radicand <= 0:
                            continue
                        float_root = 1 / math.sqrt(float(radicand))
                        max_value = radicand.with_bits(max_bits)
                        # a root that does not fit is expected to be
                        # replaced by the maximum value
                        if float_root > float(max_value):
                            root = max_value
                        else:
                            root = radicand.with_value(float_root)
                        self._check(radicand, root)

    def test_misc_cases(self):
        # radicand, expected
        cases = (
            (Fixed(0.5, 30, 32, False),
             "RootRemainder(fixed:0x1.6a09e664, "
             "fixed:0x0.0000000596d014780000000)"),
        )
        for radicand, expected in cases:
            with self.subTest(radicand=str(radicand), expected=expected):
                self.assertEqual(str(fixed_rsqrt(radicand)), expected)
834
835
class TestFixedRSqrt(unittest.TestCase):
    """Step ``FixedRSqrt`` to completion and compare with ``fixed_rsqrt``."""

    def _assert_invariants(self, obj):
        """Assert the properties that must hold before and after each stage."""
        self.assertEqual(obj.radicand * obj.root,
                         obj.radicand_root)
        self.assertEqual(obj.radicand_root * obj.root,
                         obj.radicand_root_squared)
        self.assertGreaterEqual(1,
                                obj.radicand_root_squared)

    def helper(self, log2_radix):
        """Run every non-zero unsigned radicand up to 7 bits wide.

        :param log2_radix: base-2 logarithm of the radix passed to
            ``FixedRSqrt``.
        """
        for bit_width in range(1, 8):
            stage_limit = 250 * bit_width
            for fract_width in range(bit_width):
                # zero is left out of the radicands
                for radicand_bits in range(1, 1 << bit_width):
                    radicand = Fixed.from_bits(radicand_bits,
                                               fract_width,
                                               bit_width,
                                               False)
                    root_remainder = fixed_rsqrt(radicand)
                    with self.subTest(radicand=repr(radicand),
                                      root_remainder=repr(root_remainder),
                                      log2_radix=log2_radix):
                        obj = FixedRSqrt(radicand, log2_radix)
                        finished = False
                        for _ in range(stage_limit):
                            self._assert_invariants(obj)
                            if obj.calculate_stage():
                                finished = True
                                break
                        if not finished:
                            # calculate_stage() never reported completion
                            self.fail("infinite loop")
                        self._assert_invariants(obj)
                        self.assertEqual(obj.remainder,
                                         1 - obj.radicand_root_squared)
                        self.assertEqual(obj.root, root_remainder.root)
                        self.assertEqual(obj.remainder,
                                         root_remainder.remainder)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
884
885
class TestFixedUDivRemSqrtRSqrt(unittest.TestCase):
    """Tests for the combined ``FixedUDivRemSqrtRSqrt`` algorithm.

    Results are compared with the reference functions ``div_rem``,
    ``fixed_sqrt`` and ``fixed_rsqrt``.
    """

    @staticmethod
    def show_fixed(bits, fract_width, bit_width):
        """Return ``"<str>:<repr>"`` of the unsigned ``Fixed`` for ``bits``."""
        fixed = Fixed.from_bits(bits, fract_width, bit_width, False)
        return f"{fixed!s}:{fixed!r}"

    def check_invariants(self,
                         dividend,
                         divisor_radicand,
                         operation,
                         bit_width,
                         fract_width,
                         log2_radix,
                         obj):
        """Assert the properties ``obj`` must satisfy at every stage.

        Checks that the inputs stored on ``obj`` are unchanged, that
        ``compare_lhs >= compare_rhs`` with ``remainder`` equal to their
        difference, and that ``compare_lhs``/``compare_rhs`` hold the
        values expected for ``operation``.
        """
        self.assertEqual(obj.dividend, dividend)
        self.assertEqual(obj.divisor_radicand, divisor_radicand)
        self.assertEqual(obj.operation, operation)
        self.assertEqual(obj.bit_width, bit_width)
        self.assertEqual(obj.fract_width, fract_width)
        self.assertEqual(obj.log2_radix, log2_radix)
        self.assertEqual(obj.root_times_radicand,
                         obj.quotient_root * obj.divisor_radicand)
        self.assertGreaterEqual(obj.compare_lhs, obj.compare_rhs)
        self.assertEqual(obj.remainder, obj.compare_lhs - obj.compare_rhs)
        if operation is Operation.UDivRem:
            self.assertEqual(obj.compare_lhs, obj.dividend << fract_width)
            self.assertEqual(obj.compare_rhs,
                             (obj.quotient_root * obj.divisor_radicand)
                             << fract_width)
        elif operation is Operation.SqrtRem:
            self.assertEqual(obj.compare_lhs,
                             obj.divisor_radicand << (fract_width * 2))
            self.assertEqual(obj.compare_rhs,
                             (obj.quotient_root * obj.quotient_root)
                             << fract_width)
        else:
            assert operation is Operation.RSqrtRem
            self.assertEqual(obj.compare_lhs,
                             1 << (fract_width * 3))
            self.assertEqual(obj.compare_rhs,
                             obj.quotient_root * obj.quotient_root
                             * obj.divisor_radicand)

    def handle_case(self,
                    dividend,
                    divisor_radicand,
                    operation,
                    bit_width,
                    fract_width,
                    log2_radix):
        """Run one input through the algorithm and check the result.

        The expected quotient/root and remainder come from the reference
        function for ``operation``.  Cases with a zero divisor (UDivRem),
        a zero radicand (RSqrtRem), or an expected quotient/root that does
        not fit in ``bit_width`` bits are skipped.
        """
        dividend_str = self.show_fixed(dividend,
                                       fract_width * 2,
                                       bit_width + fract_width)
        divisor_radicand_str = self.show_fixed(divisor_radicand,
                                               fract_width,
                                               bit_width)
        with self.subTest(dividend=dividend_str,
                          divisor_radicand=divisor_radicand_str,
                          operation=operation.name,
                          bit_width=bit_width,
                          fract_width=fract_width,
                          log2_radix=log2_radix):
            if operation is Operation.UDivRem:
                if divisor_radicand == 0:
                    return
                quotient_root, remainder = div_rem(dividend,
                                                   divisor_radicand,
                                                   bit_width * 3,
                                                   False)
                # the algorithm's remainder carries fract_width extra
                # low bits (see compare_lhs in check_invariants)
                remainder <<= fract_width
            elif operation is Operation.SqrtRem:
                root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand,
                                                            fract_width,
                                                            bit_width,
                                                            False))
                self.assertEqual(root_remainder.root.bit_width,
                                 bit_width)
                self.assertEqual(root_remainder.root.fract_width,
                                 fract_width)
                self.assertEqual(root_remainder.remainder.bit_width,
                                 bit_width * 2)
                self.assertEqual(root_remainder.remainder.fract_width,
                                 fract_width * 2)
                quotient_root = root_remainder.root.bits
                # reference remainder has 2 * fract_width fraction bits;
                # shift it up to the 3 * fract_width used for display below
                remainder = root_remainder.remainder.bits << fract_width
            else:
                assert operation is Operation.RSqrtRem
                if divisor_radicand == 0:
                    return
                root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand,
                                                             fract_width,
                                                             bit_width,
                                                             False))
                self.assertEqual(root_remainder.root.bit_width,
                                 bit_width)
                self.assertEqual(root_remainder.root.fract_width,
                                 fract_width)
                self.assertEqual(root_remainder.remainder.bit_width,
                                 bit_width * 3)
                self.assertEqual(root_remainder.remainder.fract_width,
                                 fract_width * 3)
                quotient_root = root_remainder.root.bits
                remainder = root_remainder.remainder.bits
            if quotient_root >= (1 << bit_width):
                # expected result does not fit in bit_width bits
                return
            quotient_root_str = self.show_fixed(quotient_root,
                                                fract_width,
                                                bit_width)
            remainder_str = self.show_fixed(remainder,
                                            fract_width * 3,
                                            bit_width * 3)
            with self.subTest(quotient_root=quotient_root_str,
                              remainder=remainder_str):
                obj = FixedUDivRemSqrtRSqrt(dividend,
                                            divisor_radicand,
                                            operation,
                                            bit_width,
                                            fract_width,
                                            log2_radix)
                for _ in range(250 * bit_width):
                    self.check_invariants(dividend,
                                          divisor_radicand,
                                          operation,
                                          bit_width,
                                          fract_width,
                                          log2_radix,
                                          obj)
                    if obj.calculate_stage():
                        break
                else:
                    # calculate_stage() never reported completion
                    self.fail("infinite loop")
                self.check_invariants(dividend,
                                      divisor_radicand,
                                      operation,
                                      bit_width,
                                      fract_width,
                                      log2_radix,
                                      obj)
                self.assertEqual(obj.quotient_root, quotient_root)
                self.assertEqual(obj.remainder, remainder)

    def helper(self, log2_radix, operation):
        """Run ``handle_case`` over all inputs for one radix and operation.

        UDivRem also iterates over every dividend and uses bit widths 1..5;
        the other operations use a single dividend of 0 and widths 1..7
        (presumably to keep the division run time manageable).
        """
        bit_width_range = range(1, 8)
        if operation is Operation.UDivRem:
            bit_width_range = range(1, 6)
        for bit_width in bit_width_range:
            for fract_width in range(bit_width):
                for divisor_radicand in range(1 << bit_width):
                    dividend_range = range(1)
                    if operation is Operation.UDivRem:
                        dividend_range = range(1 << (bit_width + fract_width))
                    for dividend in dividend_range:
                        self.handle_case(dividend,
                                         divisor_radicand,
                                         operation,
                                         bit_width,
                                         fract_width,
                                         log2_radix)

    def test_radix_2_UDiv(self):
        self.helper(1, Operation.UDivRem)

    def test_radix_4_UDiv(self):
        self.helper(2, Operation.UDivRem)

    def test_radix_8_UDiv(self):
        self.helper(3, Operation.UDivRem)

    def test_radix_16_UDiv(self):
        self.helper(4, Operation.UDivRem)

    def test_radix_2_Sqrt(self):
        self.helper(1, Operation.SqrtRem)

    def test_radix_4_Sqrt(self):
        self.helper(2, Operation.SqrtRem)

    def test_radix_8_Sqrt(self):
        self.helper(3, Operation.SqrtRem)

    def test_radix_16_Sqrt(self):
        self.helper(4, Operation.SqrtRem)

    def test_radix_2_RSqrt(self):
        self.helper(1, Operation.RSqrtRem)

    def test_radix_4_RSqrt(self):
        self.helper(2, Operation.RSqrtRem)

    def test_radix_8_RSqrt(self):
        self.helper(3, Operation.RSqrtRem)

    def test_radix_16_RSqrt(self):
        self.helper(4, Operation.RSqrtRem)

    def test_int_div(self):
        """Check integer division of all 8-bit dividends by non-zero divisors.

        The dividend is passed unshifted, so the expected quotient is the
        plain integer quotient and the expected remainder is the integer
        remainder shifted left by ``fract_width``.
        """
        bit_width = 8
        fract_width = 4
        log2_radix = 3
        for dividend in range(1 << bit_width):
            for divisor in range(1, 1 << bit_width):
                quotient, remainder = div_rem(dividend,
                                              divisor,
                                              bit_width,
                                              False)
                shifted_remainder = remainder << fract_width
                # Build and run the object inside the subTest so an
                # exception is reported together with its inputs.
                with self.subTest(dividend=dividend,
                                  divisor=divisor,
                                  quotient=quotient,
                                  remainder=remainder,
                                  shifted_remainder=shifted_remainder):
                    obj = FixedUDivRemSqrtRSqrt(dividend,
                                                divisor,
                                                Operation.UDivRem,
                                                bit_width,
                                                fract_width,
                                                log2_radix)
                    obj.calculate()
                    self.assertEqual(obj.quotient_root, quotient)
                    self.assertEqual(obj.remainder, shifted_remainder)

    def test_fract_div(self):
        """Check fractional division with the dividend pre-shifted.

        The dividend is shifted left by ``fract_width`` before being passed
        in.  Results are only checked when the expected quotient fits in
        ``bit_width`` bits.
        """
        bit_width = 8
        fract_width = 4
        log2_radix = 3
        for dividend in range(1 << bit_width):
            shifted_dividend = dividend << fract_width
            for divisor in range(1, 1 << bit_width):
                quotient = shifted_dividend // divisor
                remainder = shifted_dividend % divisor
                shifted_remainder = remainder << fract_width
                # Build and run the object inside the subTest so an
                # exception is reported together with its inputs.
                with self.subTest(dividend=dividend,
                                  divisor=divisor,
                                  quotient=quotient,
                                  remainder=remainder,
                                  shifted_remainder=shifted_remainder):
                    obj = FixedUDivRemSqrtRSqrt(shifted_dividend,
                                                divisor,
                                                Operation.UDivRem,
                                                bit_width,
                                                fract_width,
                                                log2_radix)
                    # calculate() is still exercised for quotients that do
                    # not fit in bit_width bits; only the result checks are
                    # skipped for them.
                    obj.calculate()
                    if quotient >= (1 << bit_width):
                        continue
                    self.assertEqual(obj.quotient_root, quotient)
                    self.assertEqual(obj.remainder, shifted_remainder)