change all uses of dataclass to plain_data
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_algorithm.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen.hdl.ast import Const
6 from .algorithm import (div_rem, UnsignedDivRem, DivRem,
7 Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
8 fixed_rsqrt, FixedRSqrt, Operation,
9 FixedUDivRemSqrtRSqrt)
10 import unittest
11 import math
12
13
class TestDivRemFn(unittest.TestCase):
    """Tests for the reference ``div_rem`` function on 4-bit operands."""

    def test_signed(self):
        """Check signed 4-bit division against a hand-written table.

        Per the table: quotients truncate toward zero, the remainder takes
        the sign of the numerator, dividing by zero gives a quotient of -1
        with the numerator as the remainder, and ``-8 / -1`` wraps to -8.
        """
        test_cases = [
            # numerator, denominator, quotient, remainder
            (-8, -8, 1, 0),
            (-7, -8, 0, -7),
            (-6, -8, 0, -6),
            (-5, -8, 0, -5),
            (-4, -8, 0, -4),
            (-3, -8, 0, -3),
            (-2, -8, 0, -2),
            (-1, -8, 0, -1),
            (0, -8, 0, 0),
            (1, -8, 0, 1),
            (2, -8, 0, 2),
            (3, -8, 0, 3),
            (4, -8, 0, 4),
            (5, -8, 0, 5),
            (6, -8, 0, 6),
            (7, -8, 0, 7),
            (-8, -7, 1, -1),
            (-7, -7, 1, 0),
            (-6, -7, 0, -6),
            (-5, -7, 0, -5),
            (-4, -7, 0, -4),
            (-3, -7, 0, -3),
            (-2, -7, 0, -2),
            (-1, -7, 0, -1),
            (0, -7, 0, 0),
            (1, -7, 0, 1),
            (2, -7, 0, 2),
            (3, -7, 0, 3),
            (4, -7, 0, 4),
            (5, -7, 0, 5),
            (6, -7, 0, 6),
            (7, -7, -1, 0),
            (-8, -6, 1, -2),
            (-7, -6, 1, -1),
            (-6, -6, 1, 0),
            (-5, -6, 0, -5),
            (-4, -6, 0, -4),
            (-3, -6, 0, -3),
            (-2, -6, 0, -2),
            (-1, -6, 0, -1),
            (0, -6, 0, 0),
            (1, -6, 0, 1),
            (2, -6, 0, 2),
            (3, -6, 0, 3),
            (4, -6, 0, 4),
            (5, -6, 0, 5),
            (6, -6, -1, 0),
            (7, -6, -1, 1),
            (-8, -5, 1, -3),
            (-7, -5, 1, -2),
            (-6, -5, 1, -1),
            (-5, -5, 1, 0),
            (-4, -5, 0, -4),
            (-3, -5, 0, -3),
            (-2, -5, 0, -2),
            (-1, -5, 0, -1),
            (0, -5, 0, 0),
            (1, -5, 0, 1),
            (2, -5, 0, 2),
            (3, -5, 0, 3),
            (4, -5, 0, 4),
            (5, -5, -1, 0),
            (6, -5, -1, 1),
            (7, -5, -1, 2),
            (-8, -4, 2, 0),
            (-7, -4, 1, -3),
            (-6, -4, 1, -2),
            (-5, -4, 1, -1),
            (-4, -4, 1, 0),
            (-3, -4, 0, -3),
            (-2, -4, 0, -2),
            (-1, -4, 0, -1),
            (0, -4, 0, 0),
            (1, -4, 0, 1),
            (2, -4, 0, 2),
            (3, -4, 0, 3),
            (4, -4, -1, 0),
            (5, -4, -1, 1),
            (6, -4, -1, 2),
            (7, -4, -1, 3),
            (-8, -3, 2, -2),
            (-7, -3, 2, -1),
            (-6, -3, 2, 0),
            (-5, -3, 1, -2),
            (-4, -3, 1, -1),
            (-3, -3, 1, 0),
            (-2, -3, 0, -2),
            (-1, -3, 0, -1),
            (0, -3, 0, 0),
            (1, -3, 0, 1),
            (2, -3, 0, 2),
            (3, -3, -1, 0),
            (4, -3, -1, 1),
            (5, -3, -1, 2),
            (6, -3, -2, 0),
            (7, -3, -2, 1),
            (-8, -2, 4, 0),
            (-7, -2, 3, -1),
            (-6, -2, 3, 0),
            (-5, -2, 2, -1),
            (-4, -2, 2, 0),
            (-3, -2, 1, -1),
            (-2, -2, 1, 0),
            (-1, -2, 0, -1),
            (0, -2, 0, 0),
            (1, -2, 0, 1),
            (2, -2, -1, 0),
            (3, -2, -1, 1),
            (4, -2, -2, 0),
            (5, -2, -2, 1),
            (6, -2, -3, 0),
            (7, -2, -3, 1),
            (-8, -1, -8, 0),  # overflows and wraps around
            (-7, -1, 7, 0),
            (-6, -1, 6, 0),
            (-5, -1, 5, 0),
            (-4, -1, 4, 0),
            (-3, -1, 3, 0),
            (-2, -1, 2, 0),
            (-1, -1, 1, 0),
            (0, -1, 0, 0),
            (1, -1, -1, 0),
            (2, -1, -2, 0),
            (3, -1, -3, 0),
            (4, -1, -4, 0),
            (5, -1, -5, 0),
            (6, -1, -6, 0),
            (7, -1, -7, 0),
            (-8, 0, -1, -8),
            (-7, 0, -1, -7),
            (-6, 0, -1, -6),
            (-5, 0, -1, -5),
            (-4, 0, -1, -4),
            (-3, 0, -1, -3),
            (-2, 0, -1, -2),
            (-1, 0, -1, -1),
            (0, 0, -1, 0),
            (1, 0, -1, 1),
            (2, 0, -1, 2),
            (3, 0, -1, 3),
            (4, 0, -1, 4),
            (5, 0, -1, 5),
            (6, 0, -1, 6),
            (7, 0, -1, 7),
            (-8, 1, -8, 0),
            (-7, 1, -7, 0),
            (-6, 1, -6, 0),
            (-5, 1, -5, 0),
            (-4, 1, -4, 0),
            (-3, 1, -3, 0),
            (-2, 1, -2, 0),
            (-1, 1, -1, 0),
            (0, 1, 0, 0),
            (1, 1, 1, 0),
            (2, 1, 2, 0),
            (3, 1, 3, 0),
            (4, 1, 4, 0),
            (5, 1, 5, 0),
            (6, 1, 6, 0),
            (7, 1, 7, 0),
            (-8, 2, -4, 0),
            (-7, 2, -3, -1),
            (-6, 2, -3, 0),
            (-5, 2, -2, -1),
            (-4, 2, -2, 0),
            (-3, 2, -1, -1),
            (-2, 2, -1, 0),
            (-1, 2, 0, -1),
            (0, 2, 0, 0),
            (1, 2, 0, 1),
            (2, 2, 1, 0),
            (3, 2, 1, 1),
            (4, 2, 2, 0),
            (5, 2, 2, 1),
            (6, 2, 3, 0),
            (7, 2, 3, 1),
            (-8, 3, -2, -2),
            (-7, 3, -2, -1),
            (-6, 3, -2, 0),
            (-5, 3, -1, -2),
            (-4, 3, -1, -1),
            (-3, 3, -1, 0),
            (-2, 3, 0, -2),
            (-1, 3, 0, -1),
            (0, 3, 0, 0),
            (1, 3, 0, 1),
            (2, 3, 0, 2),
            (3, 3, 1, 0),
            (4, 3, 1, 1),
            (5, 3, 1, 2),
            (6, 3, 2, 0),
            (7, 3, 2, 1),
            (-8, 4, -2, 0),
            (-7, 4, -1, -3),
            (-6, 4, -1, -2),
            (-5, 4, -1, -1),
            (-4, 4, -1, 0),
            (-3, 4, 0, -3),
            (-2, 4, 0, -2),
            (-1, 4, 0, -1),
            (0, 4, 0, 0),
            (1, 4, 0, 1),
            (2, 4, 0, 2),
            (3, 4, 0, 3),
            (4, 4, 1, 0),
            (5, 4, 1, 1),
            (6, 4, 1, 2),
            (7, 4, 1, 3),
            (-8, 5, -1, -3),
            (-7, 5, -1, -2),
            (-6, 5, -1, -1),
            (-5, 5, -1, 0),
            (-4, 5, 0, -4),
            (-3, 5, 0, -3),
            (-2, 5, 0, -2),
            (-1, 5, 0, -1),
            (0, 5, 0, 0),
            (1, 5, 0, 1),
            (2, 5, 0, 2),
            (3, 5, 0, 3),
            (4, 5, 0, 4),
            (5, 5, 1, 0),
            (6, 5, 1, 1),
            (7, 5, 1, 2),
            (-8, 6, -1, -2),
            (-7, 6, -1, -1),
            (-6, 6, -1, 0),
            (-5, 6, 0, -5),
            (-4, 6, 0, -4),
            (-3, 6, 0, -3),
            (-2, 6, 0, -2),
            (-1, 6, 0, -1),
            (0, 6, 0, 0),
            (1, 6, 0, 1),
            (2, 6, 0, 2),
            (3, 6, 0, 3),
            (4, 6, 0, 4),
            (5, 6, 0, 5),
            (6, 6, 1, 0),
            (7, 6, 1, 1),
            (-8, 7, -1, -1),
            (-7, 7, -1, 0),
            (-6, 7, 0, -6),
            (-5, 7, 0, -5),
            (-4, 7, 0, -4),
            (-3, 7, 0, -3),
            (-2, 7, 0, -2),
            (-1, 7, 0, -1),
            (0, 7, 0, 0),
            (1, 7, 0, 1),
            (2, 7, 0, 2),
            (3, 7, 0, 3),
            (4, 7, 0, 4),
            (5, 7, 0, 5),
            (6, 7, 0, 6),
            (7, 7, 1, 0),
        ]
        for n, d, q, r in test_cases:
            # subTest labels each table row, so a failure names its inputs
            # and the remaining rows are still checked.
            with self.subTest(n=n, d=d, q=q, r=r):
                self.assertEqual(div_rem(n, d, 4, True), (q, r))

    def test_unsigned(self):
        """Check unsigned 4-bit division for every operand pair.

        Division by zero is expected to give an all-ones quotient and to
        return the numerator as the remainder.
        """
        bit_width = 4
        for n in range(1 << bit_width):
            for d in range(1 << bit_width):
                if d == 0:
                    q = (1 << bit_width) - 1
                    r = n
                else:
                    # div_rem matches // and % for unsigned integers
                    q = n // d
                    r = n % d
                with self.subTest(n=n, d=d, q=q, r=r):
                    self.assertEqual(div_rem(n, d, bit_width, False), (q, r))
289
290
class TestUnsignedDivRem(unittest.TestCase):
    """Step-by-step tests for the ``UnsignedDivRem`` state machine."""

    def _assert_invariants(self, state, dividend, divisor):
        """Check the relations that must hold before and after every stage."""
        self.assertEqual(state.dividend, dividend)
        self.assertEqual(state.divisor, divisor)
        self.assertEqual(state.quotient_times_divisor,
                         state.quotient * state.divisor)
        self.assertGreaterEqual(state.dividend,
                                state.quotient_times_divisor)

    def helper(self, log2_radix):
        """Run every 4-bit dividend/divisor pair to completion at one radix.

        The state machine is advanced one ``calculate_stage`` call at a time.
        The number of stages is bounded, so a machine that never reports
        completion fails rather than hangs. The final quotient and remainder
        are compared with the reference ``div_rem``.
        """
        bit_width = 4
        max_stages = 250 * bit_width
        operands = range(1 << bit_width)
        for dividend in operands:
            for divisor in operands:
                quotient, remainder = div_rem(dividend, divisor,
                                              bit_width, False)
                with self.subTest(n=dividend, d=divisor,
                                  q=quotient, r=remainder):
                    state = UnsignedDivRem(dividend, divisor,
                                           bit_width, log2_radix)
                    finished = False
                    for _ in range(max_stages):
                        self._assert_invariants(state, dividend, divisor)
                        if state.calculate_stage():
                            finished = True
                            break
                    if not finished:
                        self.fail("infinite loop")
                    self._assert_invariants(state, dividend, divisor)
                    self.assertEqual(state.quotient, quotient)
                    self.assertEqual(state.remainder, remainder)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
330
331
class TestDivRem(unittest.TestCase):
    """Step-by-step tests for the ``DivRem`` state machine."""

    def helper(self, log2_radix):
        """Check every 4-bit operand pair, unsigned and signed, at one radix.

        ``DivRem`` is advanced one ``calculate_stage`` call at a time. The
        number of stages is bounded, so a machine that never finishes fails
        rather than hangs. The final quotient and remainder are compared
        with the reference ``div_rem``.
        """
        bit_width = 4
        for n_bits in range(1 << bit_width):
            for d_bits in range(1 << bit_width):
                for signed in False, True:
                    # Reinterpret the raw bit patterns under the current
                    # signedness into fresh names. This keeps the loop
                    # variables from being overwritten, so the next
                    # iteration never depends on normalize() undoing a
                    # previous reinterpretation.
                    n = Const.normalize(n_bits, (bit_width, signed))
                    d = Const.normalize(d_bits, (bit_width, signed))
                    q, r = div_rem(n, d, bit_width, signed)
                    with self.subTest(n=n, d=d, q=q, r=r, signed=signed):
                        dr = DivRem(n, d, bit_width, signed, log2_radix)
                        for _ in range(250 * bit_width):
                            if dr.calculate_stage():
                                break
                        else:
                            self.fail("infinite loop")
                        self.assertEqual(dr.quotient, q)
                        self.assertEqual(dr.remainder, r)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
362
363
class TestFixed(unittest.TestCase):
    """Tests for the software ``Fixed`` fixed-point number type."""

    def _check_fields(self, value, bits, fract_width, bit_width, signed):
        """Assert bits, fract_width, bit_width and signed, in that order."""
        self.assertEqual(value.bits, bits)
        self.assertEqual(value.fract_width, fract_width)
        self.assertEqual(value.bit_width, bit_width)
        self.assertEqual(value.signed, signed)

    @staticmethod
    def _signed_q2_values():
        """Yield ``(raw, value)`` for all 16 signed 4-bit values with 2
        fraction bits."""
        for raw in range(-8, 8):
            yield raw, Fixed.from_bits(raw, 2, 4, True)

    def test_constructor(self):
        # Fixed(value, fract_width, bit_width, signed) from plain ints
        value = Fixed(0, 0, 1, False)
        self._check_fields(value, 0, 0, 1, False)
        # 1 << 2 == 4 does not fit in 3 signed bits and wraps to -4
        value = Fixed(1, 2, 3, True)
        self._check_fields(value, -4, 2, 3, True)
        # with a fourth bit the same value fits
        value = Fixed(1, 2, 4, True)
        self._check_fields(value, 4, 2, 4, True)
        # from a float: 1.25 * 2 ** 4 == 0x14
        value = Fixed(1.25, 4, 8, True)
        self._check_fields(value, 0x14, 4, 8, True)
        # from another Fixed with a different format
        value = Fixed(Fixed(2, 0, 12, False), 4, 8, True)
        self._check_fields(value, 0x20, 4, 8, True)
        value = Fixed(0x2FF / 2 ** 8, 8, 12, False)
        self._check_fields(value, 0x2FF, 8, 12, False)
        # narrowing the fraction from 8 to 4 bits drops the low bits
        value = Fixed(value, 4, 8, True)
        self._check_fields(value, 0x2F, 4, 8, True)

    def helper_tst_from_bits(self, bit_width, fract_width):
        """Round-trip every representable bit pattern through ``from_bits``.

        All unsigned patterns are checked first, then all signed ones.
        """
        pattern_sets = (
            (False, range(1 << bit_width)),
            (True, range(-1 << (bit_width - 1), 1 << (bit_width - 1))),
        )
        for signed, patterns in pattern_sets:
            for bits in patterns:
                with self.subTest(bit_width=bit_width,
                                  fract_width=fract_width,
                                  signed=signed,
                                  bits=hex(bits)):
                    value = Fixed.from_bits(bits, fract_width,
                                            bit_width, signed)
                    self.assertEqual(value.bit_width, bit_width)
                    self.assertEqual(value.fract_width, fract_width)
                    self.assertEqual(value.signed, signed)
                    self.assertEqual(value.bits, bits)

    def test_from_bits(self):
        formats = ((bit_width, fract_width)
                   for bit_width in range(1, 5)
                   for fract_width in range(bit_width))
        for bit_width, fract_width in formats:
            self.helper_tst_from_bits(bit_width, fract_width)

    def test_repr(self):
        cases = [
            ((1, 2, 3, False), "Fixed.from_bits(1, 2, 3, False)"),
            ((-4, 2, 3, True), "Fixed.from_bits(-4, 2, 3, True)"),
            ((-4, 7, 10, True), "Fixed.from_bits(-4, 7, 10, True)"),
        ]
        for args, expected in cases:
            self.assertEqual(repr(Fixed.from_bits(*args)), expected)

    def test_trunc(self):
        for raw, value in self._signed_q2_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(math.trunc(value), math.trunc(raw / 4))

    def test_int(self):
        for _, value in self._signed_q2_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(int(value), math.trunc(value))

    def test_float(self):
        for raw, value in self._signed_q2_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(float(value), raw / 4)

    def test_floor(self):
        for raw, value in self._signed_q2_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(math.floor(value), math.floor(raw / 4))

    def test_ceil(self):
        for raw, value in self._signed_q2_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(math.ceil(value), math.ceil(raw / 4))

    def test_neg(self):
        for raw, value in self._signed_q2_values():
            # +2.0 is not representable, so negating -2.0 wraps to -2.0
            expected = -2.0 if raw == -8 else -raw / 4
            with self.subTest(value=repr(value)):
                self.assertEqual(float(-value), expected)

    def test_pos(self):
        for raw, value in self._signed_q2_values():
            with self.subTest(value=repr(value)):
                self.assertEqual((+value).bits, raw)

    def test_abs(self):
        for raw, value in self._signed_q2_values():
            # abs(-2.0) is not representable and wraps to -2.0
            expected = -2.0 if raw == -8 else abs(raw) / 4
            with self.subTest(value=repr(value)):
                self.assertEqual(float(abs(value)), expected)

    def test_not(self):
        for raw, value in self._signed_q2_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(float(~value), (~raw) / 4)

    @staticmethod
    def get_test_values(max_bit_width, include_int):
        """Yield operands for the binary-operator tests.

        For each signedness (unsigned first) this optionally yields every
        plain ``int`` representable in ``max_bit_width`` bits. It then yields
        every ``Fixed`` value with ``bit_width`` in ``1 .. max_bit_width - 1``
        and ``fract_width`` in ``0 .. bit_width``.
        """
        for signed in (False, True):
            if include_int:
                shape = (max_bit_width, signed)
                yield from (Const.normalize(bits, shape)
                            for bits in range(1 << max_bit_width))
            for bit_width in range(1, max_bit_width):
                for fract_width in range(bit_width + 1):
                    for bits in range(1 << bit_width):
                        yield Fixed.from_bits(bits, fract_width,
                                              bit_width, signed)

    def binary_op_test_helper(self,
                              operation,
                              is_fixed=True,
                              width_combine_op=max,
                              adjust_bits_op=None):
        """Compare ``operation`` on ``Fixed`` values against integer math.

        The expected result is ``operation`` applied to the operands' raw
        integer bits, after ``adjust_bits_op`` has aligned them to the
        result's fraction width.

        - For an int/``Fixed`` pair, the result uses the ``Fixed`` operand's
          format.
        - ``Fixed`` pairs of differing signedness are skipped.
        - Other ``Fixed`` pairs get fraction and integer widths from
          ``width_combine_op``.

        When ``is_fixed`` is false, the result is compared as a plain value.
        """
        if adjust_bits_op is None:
            # default: shift so the binary point matches the result format
            adjust_bits_op = (
                lambda bits, out_fract_width, in_fract_width:
                bits << (out_fract_width - in_fract_width))
        max_bit_width = 5
        for lhs in self.get_test_values(max_bit_width, True):
            lhs_is_int = isinstance(lhs, int)
            for rhs in self.get_test_values(max_bit_width, not lhs_is_int):
                rhs_is_int = isinstance(rhs, int)
                if lhs_is_int:
                    assert not rhs_is_int
                    lhs_int = adjust_bits_op(lhs, rhs.fract_width, 0)
                    rhs_int = rhs.bits
                    result_format = (rhs.fract_width, rhs.bit_width,
                                     rhs.signed)
                elif rhs_is_int:
                    lhs_int = lhs.bits
                    rhs_int = adjust_bits_op(rhs, lhs.fract_width, 0)
                    result_format = (lhs.fract_width, lhs.bit_width,
                                     lhs.signed)
                elif lhs.signed != rhs.signed:
                    continue
                else:
                    fract_width = width_combine_op(lhs.fract_width,
                                                   rhs.fract_width)
                    int_width = width_combine_op(
                        lhs.bit_width - lhs.fract_width,
                        rhs.bit_width - rhs.fract_width)
                    lhs_int = adjust_bits_op(lhs.bits, fract_width,
                                             lhs.fract_width)
                    rhs_int = adjust_bits_op(rhs.bits, fract_width,
                                             rhs.fract_width)
                    result_format = (fract_width, fract_width + int_width,
                                     lhs.signed)
                int_result = operation(lhs_int, rhs_int)
                if is_fixed:
                    expected = Fixed.from_bits(int_result, *result_format)
                else:
                    expected = int_result
                with self.subTest(lhs=repr(lhs),
                                  rhs=repr(rhs),
                                  expected=repr(expected)):
                    result = operation(lhs, rhs)
                    if not is_fixed:
                        self.assertEqual(result, expected)
                    else:
                        self.assertEqual(result.bit_width,
                                         expected.bit_width)
                        self.assertEqual(result.signed, expected.signed)
                        self.assertEqual(result.fract_width,
                                         expected.fract_width)
                        self.assertEqual(result.bits, expected.bits)

    def test_add(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs + rhs)

    def test_sub(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs - rhs)

    def test_and(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs & rhs)

    def test_or(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs | rhs)

    def test_xor(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs ^ rhs)

    def test_mul(self):
        # Raw bits are multiplied without aligning binary points. The
        # result's fraction and integer widths are the sums of the operands'.
        self.binary_op_test_helper(
            lambda lhs, rhs: lhs * rhs,
            is_fixed=True,
            width_combine_op=lambda l_width, r_width: l_width + r_width,
            adjust_bits_op=lambda bits, out_fract_width, in_fract_width: bits)

    def test_cmp(self):
        def three_way(lhs, rhs):
            if lhs < rhs:
                return -1
            return 1 if lhs > rhs else 0
        self.binary_op_test_helper(three_way, False)

    def test_lt(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs < rhs, False)

    def test_le(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs <= rhs, False)

    def test_eq(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs == rhs, False)

    def test_ne(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs != rhs, False)

    def test_gt(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs > rhs, False)

    def test_ge(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs >= rhs, False)

    def test_bool(self):
        for fixed in self.get_test_values(6, False):
            with self.subTest(v=repr(fixed)):
                self.assertEqual(bool(fixed), bool(fixed.bits))

    def test_str(self):
        self.assertEqual(str(Fixed.from_bits(0x1234, 0, 16, False)),
                         "fixed:0x1234.")
        self.assertEqual(str(Fixed.from_bits(-0x1234, 0, 16, True)),
                         "fixed:-0x1234.")
        self.assertEqual(str(Fixed.from_bits(0x12345, 3, 20, True)),
                         "fixed:0x2468.a")
        self.assertEqual(str(Fixed(123.625, 3, 12, True)),
                         "fixed:0x7b.a")

        # (bits, fract_width, signed, expected); every value is 20 bits wide
        twenty_bit_cases = [
            (0x1, 0, True, "fixed:0x1."),
            (0x2, 1, True, "fixed:0x1.0"),
            (0x4, 2, True, "fixed:0x1.0"),
            (0x9, 3, True, "fixed:0x1.2"),
            (0x12, 4, True, "fixed:0x1.2"),
            (0x24, 5, True, "fixed:0x1.20"),
            (0x48, 6, True, "fixed:0x1.20"),
            (0x91, 7, True, "fixed:0x1.22"),
            (0x123, 8, True, "fixed:0x1.23"),
            (0x246, 9, True, "fixed:0x1.230"),
            (0x48d, 10, True, "fixed:0x1.234"),
            (0x91a, 11, True, "fixed:0x1.234"),
            (0x1234, 12, True, "fixed:0x1.234"),
            (0x2468, 13, True, "fixed:0x1.2340"),
            (0x48d1, 14, True, "fixed:0x1.2344"),
            (0x91a2, 15, True, "fixed:0x1.2344"),
            (0x12345, 16, True, "fixed:0x1.2345"),
            (0x2468a, 17, True, "fixed:0x1.23450"),
            (0x48d14, 18, True, "fixed:0x1.23450"),
            (0x91a28, 19, True, "fixed:-0x0.dcbb0"),
            (0x91a28, 19, False, "fixed:0x1.23450"),
        ]
        for bits, fract_width, signed, expected in twenty_bit_cases:
            self.assertEqual(
                str(Fixed.from_bits(bits, fract_width, 20, signed)),
                expected)
690
691
class TestFixedSqrtFn(unittest.TestCase):
    """Tests for the reference ``fixed_sqrt`` function."""

    @staticmethod
    def _int_root_remainder(radicand, root):
        """Pair ``root`` with ``radicand - root * root``."""
        return RootRemainder(root, radicand - root * root)

    def _check_int(self, radicand, expected):
        with self.subTest(radicand=radicand, expected=expected):
            self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))

    @staticmethod
    def _nonnegative_fixed_radicands():
        """Yield every non-negative ``Fixed`` with 1..9 bits, both
        signednesses."""
        for signed in (False, True):
            for bit_width in range(1, 10):
                for fract_width in range(bit_width):
                    for bits in range(1 << bit_width):
                        candidate = Fixed.from_bits(bits, fract_width,
                                                    bit_width, signed)
                        if not candidate < 0:
                            yield candidate

    def test_on_ints(self):
        for radicand in range(-1, 32):
            if radicand < 0:
                # a negative radicand is expected to yield None
                self._check_int(radicand, None)
            else:
                root = math.floor(math.sqrt(radicand))
                self._check_int(radicand,
                                self._int_root_remainder(radicand, root))
        big = 2 << 64
        # 0x16A09E667 is the expected integer square root of 2 << 64
        self._check_int(big, self._int_root_remainder(big, 0x16A09E667))

    def test_on_fixed(self):
        for radicand in self._nonnegative_fixed_radicands():
            root = radicand.with_value(math.sqrt(float(radicand)))
            expected = RootRemainder(root, radicand - root * root)
            with self.subTest(radicand=repr(radicand),
                              expected=repr(expected)):
                self.assertEqual(repr(fixed_sqrt(radicand)),
                                 repr(expected))

    def test_misc_cases(self):
        # (radicand, expected str() of the result)
        cases = (
            (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
            (Fixed(2, 30, 32, False),
             "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)"),
        )
        for radicand, expected in cases:
            with self.subTest(radicand=str(radicand), expected=expected):
                self.assertEqual(str(fixed_sqrt(radicand)), expected)
739
740
class TestFixedSqrt(unittest.TestCase):
    """Step-by-step tests for the ``FixedSqrt`` state machine."""

    def _assert_invariants(self, obj):
        """``root_squared`` tracks ``root ** 2`` and never exceeds the
        radicand."""
        self.assertEqual(obj.root * obj.root, obj.root_squared)
        self.assertGreaterEqual(obj.radicand, obj.root_squared)

    def _run_stages(self, obj, max_stages):
        """Advance ``obj`` stage by stage, checking invariants before each.

        Returns True once ``calculate_stage`` reports completion. Returns
        False if it did not complete within ``max_stages`` calls.
        """
        for _ in range(max_stages):
            self._assert_invariants(obj)
            if obj.calculate_stage():
                return True
        return False

    def helper(self, log2_radix):
        """Compare ``FixedSqrt`` against ``fixed_sqrt`` for all small
        unsigned values."""
        for bit_width in range(1, 8):
            for fract_width in range(bit_width):
                for radicand_bits in range(1 << bit_width):
                    radicand = Fixed.from_bits(radicand_bits, fract_width,
                                               bit_width, False)
                    reference = fixed_sqrt(radicand)
                    with self.subTest(radicand=repr(radicand),
                                      root_remainder=repr(reference),
                                      log2_radix=log2_radix):
                        obj = FixedSqrt(radicand, log2_radix)
                        if not self._run_stages(obj, 250 * bit_width):
                            self.fail("infinite loop")
                        self._assert_invariants(obj)
                        self.assertEqual(obj.remainder,
                                         obj.radicand - obj.root_squared)
                        self.assertEqual(obj.root, reference.root)
                        self.assertEqual(obj.remainder, reference.remainder)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
785
786
class TestFixedRSqrtFn(unittest.TestCase):
    """Tests for the reference ``fixed_rsqrt`` (1 / sqrt) function."""

    @staticmethod
    def _expected(radicand, max_value=None):
        """Build the expected ``RootRemainder`` for ``radicand`` via floats.

        The root is ``1 / sqrt(radicand)``, converted to ``radicand``'s
        format with ``with_value``. If ``max_value`` is given and the float
        root exceeds it, the root is clamped to ``max_value``. The remainder
        is ``1 - root * root * radicand``.
        """
        float_root = 1 / math.sqrt(float(radicand))
        if max_value is not None and float_root > float(max_value):
            root = max_value
        else:
            root = radicand.with_value(float_root)
        return RootRemainder(root, 1 - root * root * radicand)

    def _check(self, radicand, expected):
        with self.subTest(radicand=repr(radicand),
                          expected=repr(expected)):
            self.assertEqual(repr(fixed_rsqrt(radicand)), repr(expected))

    def test2(self):
        # every non-zero 5-bit pattern as a 12-bit unsigned value with
        # 5 fraction bits
        for bits in range(1, 1 << 5):
            radicand = Fixed.from_bits(bits, 5, 12, False)
            self._check(radicand, self._expected(radicand))

    def test(self):
        for signed in False, True:
            for bit_width in range(1, 10):
                for fract_width in range(bit_width):
                    for bits in range(1 << bit_width):
                        radicand = Fixed.from_bits(bits, fract_width,
                                                   bit_width, signed)
                        if radicand <= 0:
                            # 1 / sqrt(x) is undefined for x <= 0
                            continue
                        # largest positive bit pattern in this format
                        max_value = radicand.with_bits(
                            (1 << (bit_width - signed)) - 1)
                        self._check(radicand,
                                    self._expected(radicand, max_value))

    def test_misc_cases(self):
        # (radicand, expected str() of the result)
        cases = (
            (Fixed(0.5, 30, 32, False),
             "RootRemainder(fixed:0x1.6a09e664, "
             "fixed:0x0.0000000596d014780000000)"),
        )
        for radicand, expected in cases:
            with self.subTest(radicand=str(radicand), expected=expected):
                self.assertEqual(str(fixed_rsqrt(radicand)), expected)
835
836
class TestFixedRSqrt(unittest.TestCase):
    """Step-by-step tests for the ``FixedRSqrt`` state machine."""

    def _assert_invariants(self, obj):
        """The tracked products stay consistent and never exceed 1."""
        self.assertEqual(obj.radicand * obj.root, obj.radicand_root)
        self.assertEqual(obj.radicand_root * obj.root,
                         obj.radicand_root_squared)
        self.assertGreaterEqual(1, obj.radicand_root_squared)

    def _run_stages(self, obj, max_stages):
        """Advance ``obj`` stage by stage, checking invariants before each.

        Returns True once ``calculate_stage`` reports completion. Returns
        False if it did not complete within ``max_stages`` calls.
        """
        for _ in range(max_stages):
            self._assert_invariants(obj)
            if obj.calculate_stage():
                return True
        return False

    def helper(self, log2_radix):
        """Compare ``FixedRSqrt`` against ``fixed_rsqrt`` for all small
        positive unsigned values."""
        for bit_width in range(1, 8):
            for fract_width in range(bit_width):
                # zero is excluded: its reciprocal square root is undefined
                for radicand_bits in range(1, 1 << bit_width):
                    radicand = Fixed.from_bits(radicand_bits, fract_width,
                                               bit_width, False)
                    reference = fixed_rsqrt(radicand)
                    with self.subTest(radicand=repr(radicand),
                                      root_remainder=repr(reference),
                                      log2_radix=log2_radix):
                        obj = FixedRSqrt(radicand, log2_radix)
                        if not self._run_stages(obj, 250 * bit_width):
                            self.fail("infinite loop")
                        self._assert_invariants(obj)
                        self.assertEqual(obj.remainder,
                                         1 - obj.radicand_root_squared)
                        self.assertEqual(obj.root, reference.root)
                        self.assertEqual(obj.remainder, reference.remainder)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
885
886
class TestFixedUDivRemSqrtRSqrt(unittest.TestCase):
    """Tests for the combined ``FixedUDivRemSqrtRSqrt`` state machine.

    Each case is advanced one ``calculate_stage`` call at a time while its
    invariants are checked. The final ``quotient_root`` and ``remainder``
    are then compared with the reference ``div_rem``, ``fixed_sqrt`` or
    ``fixed_rsqrt`` result.
    """

    @staticmethod
    def show_fixed(bits, fract_width, bit_width):
        """Format unsigned ``bits`` as ``"<str>:<repr>"`` for subTest labels."""
        fixed = Fixed.from_bits(bits, fract_width, bit_width, False)
        return f"{str(fixed)}:{repr(fixed)}"

    def check_invariants(self,
                         dividend,
                         divisor_radicand,
                         operation,
                         bit_width,
                         fract_width,
                         log2_radix,
                         obj):
        """Assert the relations that must hold before and after every stage.

        - The construction parameters are unchanged.
        - ``root_times_radicand`` equals ``quotient_root * divisor_radicand``.
        - ``compare_lhs >= compare_rhs``, and ``remainder`` is their
          difference.
        - The two compare operands have the operation-specific form checked
          below.
        """
        self.assertEqual(obj.dividend, dividend)
        self.assertEqual(obj.divisor_radicand, divisor_radicand)
        self.assertEqual(obj.operation, operation)
        self.assertEqual(obj.bit_width, bit_width)
        self.assertEqual(obj.fract_width, fract_width)
        self.assertEqual(obj.log2_radix, log2_radix)
        self.assertEqual(obj.root_times_radicand,
                         obj.quotient_root * obj.divisor_radicand)
        self.assertGreaterEqual(obj.compare_lhs, obj.compare_rhs)
        self.assertEqual(obj.remainder, obj.compare_lhs - obj.compare_rhs)
        if operation is Operation.UDivRem:
            # dividend vs. quotient * divisor, both shifted by fract_width
            self.assertEqual(obj.compare_lhs, obj.dividend << fract_width)
            self.assertEqual(obj.compare_rhs,
                             (obj.quotient_root * obj.divisor_radicand)
                             << fract_width)
        elif operation is Operation.SqrtRem:
            # radicand vs. root ** 2, shifted to a common binary point
            self.assertEqual(obj.compare_lhs,
                             obj.divisor_radicand << (fract_width * 2))
            self.assertEqual(obj.compare_rhs,
                             (obj.quotient_root * obj.quotient_root)
                             << fract_width)
        else:
            assert operation is Operation.RSqrtRem
            # 1.0 (with 3 * fract_width fraction bits) vs. root**2 * radicand
            self.assertEqual(obj.compare_lhs,
                             1 << (fract_width * 3))
            self.assertEqual(obj.compare_rhs,
                             obj.quotient_root * obj.quotient_root
                             * obj.divisor_radicand)

    def handle_case(self,
                    dividend,
                    divisor_radicand,
                    operation,
                    bit_width,
                    fract_width,
                    log2_radix):
        """Run one input through the state machine and check the result.

        Division by zero and RSqrt of zero are skipped, as are cases whose
        expected quotient/root does not fit in ``bit_width`` bits.
        """
        dividend_str = self.show_fixed(dividend,
                                       fract_width * 2,
                                       bit_width + fract_width)
        divisor_radicand_str = self.show_fixed(divisor_radicand,
                                               fract_width,
                                               bit_width)
        with self.subTest(dividend=dividend_str,
                          divisor_radicand=divisor_radicand_str,
                          operation=operation.name,
                          bit_width=bit_width,
                          fract_width=fract_width,
                          log2_radix=log2_radix):
            if operation is Operation.UDivRem:
                if divisor_radicand == 0:
                    return
                quotient_root, remainder = div_rem(dividend,
                                                   divisor_radicand,
                                                   bit_width * 3,
                                                   False)
                remainder <<= fract_width
            elif operation is Operation.SqrtRem:
                root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand,
                                                            fract_width,
                                                            bit_width,
                                                            False))
                self.assertEqual(root_remainder.root.bit_width,
                                 bit_width)
                self.assertEqual(root_remainder.root.fract_width,
                                 fract_width)
                self.assertEqual(root_remainder.remainder.bit_width,
                                 bit_width * 2)
                self.assertEqual(root_remainder.remainder.fract_width,
                                 fract_width * 2)
                quotient_root = root_remainder.root.bits
                remainder = root_remainder.remainder.bits << fract_width
            else:
                assert operation is Operation.RSqrtRem
                if divisor_radicand == 0:
                    return
                root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand,
                                                             fract_width,
                                                             bit_width,
                                                             False))
                self.assertEqual(root_remainder.root.bit_width,
                                 bit_width)
                self.assertEqual(root_remainder.root.fract_width,
                                 fract_width)
                self.assertEqual(root_remainder.remainder.bit_width,
                                 bit_width * 3)
                self.assertEqual(root_remainder.remainder.fract_width,
                                 fract_width * 3)
                quotient_root = root_remainder.root.bits
                remainder = root_remainder.remainder.bits
            if quotient_root >= (1 << bit_width):
                # the expected result does not fit in bit_width bits
                return
            quotient_root_str = self.show_fixed(quotient_root,
                                                fract_width,
                                                bit_width)
            remainder_str = self.show_fixed(remainder,
                                            fract_width * 3,
                                            bit_width * 3)
            with self.subTest(quotient_root=quotient_root_str,
                              remainder=remainder_str):
                obj = FixedUDivRemSqrtRSqrt(dividend,
                                            divisor_radicand,
                                            operation,
                                            bit_width,
                                            fract_width,
                                            log2_radix)
                for _ in range(250 * bit_width):
                    self.check_invariants(dividend,
                                          divisor_radicand,
                                          operation,
                                          bit_width,
                                          fract_width,
                                          log2_radix,
                                          obj)
                    if obj.calculate_stage():
                        break
                else:
                    self.fail("infinite loop")
                self.check_invariants(dividend,
                                      divisor_radicand,
                                      operation,
                                      bit_width,
                                      fract_width,
                                      log2_radix,
                                      obj)
                self.assertEqual(obj.quotient_root, quotient_root)
                self.assertEqual(obj.remainder, remainder)

    def helper(self, log2_radix, operation):
        """Exhaustively check small operand widths for one operation.

        UDivRem additionally iterates over every dividend of
        ``bit_width + fract_width`` bits. Its ``bit_width`` range is smaller,
        presumably to keep the runtime reasonable.
        """
        bit_width_range = range(1, 8)
        if operation is Operation.UDivRem:
            bit_width_range = range(1, 6)
        for bit_width in bit_width_range:
            for fract_width in range(bit_width):
                for divisor_radicand in range(1 << bit_width):
                    dividend_range = range(1)
                    if operation is Operation.UDivRem:
                        dividend_range = range(1 << (bit_width + fract_width))
                    for dividend in dividend_range:
                        self.handle_case(dividend,
                                         divisor_radicand,
                                         operation,
                                         bit_width,
                                         fract_width,
                                         log2_radix)

    def test_radix_2_UDiv(self):
        self.helper(1, Operation.UDivRem)

    def test_radix_4_UDiv(self):
        self.helper(2, Operation.UDivRem)

    def test_radix_8_UDiv(self):
        self.helper(3, Operation.UDivRem)

    def test_radix_16_UDiv(self):
        self.helper(4, Operation.UDivRem)

    def test_radix_2_Sqrt(self):
        self.helper(1, Operation.SqrtRem)

    def test_radix_4_Sqrt(self):
        self.helper(2, Operation.SqrtRem)

    def test_radix_8_Sqrt(self):
        self.helper(3, Operation.SqrtRem)

    def test_radix_16_Sqrt(self):
        self.helper(4, Operation.SqrtRem)

    def test_radix_2_RSqrt(self):
        self.helper(1, Operation.RSqrtRem)

    def test_radix_4_RSqrt(self):
        self.helper(2, Operation.RSqrtRem)

    def test_radix_8_RSqrt(self):
        self.helper(3, Operation.RSqrtRem)

    def test_radix_16_RSqrt(self):
        self.helper(4, Operation.RSqrtRem)

    def test_int_div(self):
        """Integer dividends give the integer quotient; the remainder comes
        back shifted left by ``fract_width``."""
        bit_width = 8
        fract_width = 4
        log2_radix = 3
        for dividend in range(1 << bit_width):
            for divisor in range(1, 1 << bit_width):
                quotient, remainder = div_rem(dividend,
                                              divisor,
                                              bit_width,
                                              False)
                shifted_remainder = remainder << fract_width
                with self.subTest(dividend=dividend,
                                  divisor=divisor,
                                  quotient=quotient,
                                  remainder=remainder,
                                  shifted_remainder=shifted_remainder):
                    # Build and run the model inside subTest so an exception
                    # from calculate() is attributed to these inputs.
                    obj = FixedUDivRemSqrtRSqrt(dividend,
                                                divisor,
                                                Operation.UDivRem,
                                                bit_width,
                                                fract_width,
                                                log2_radix)
                    obj.calculate()
                    self.assertEqual(obj.quotient_root, quotient)
                    self.assertEqual(obj.remainder, shifted_remainder)

    def test_fract_div(self):
        """Dividends pre-shifted by ``fract_width`` produce fractional
        quotients; cases whose quotient overflows ``bit_width`` are skipped."""
        bit_width = 8
        fract_width = 4
        log2_radix = 3
        for dividend in range(1 << bit_width):
            scaled_dividend = dividend << fract_width
            for divisor in range(1, 1 << bit_width):
                quotient, remainder = divmod(scaled_dividend, divisor)
                if quotient >= (1 << bit_width):
                    # Skip before constructing/running the model; its
                    # result would be discarded anyway.
                    continue
                shifted_remainder = remainder << fract_width
                with self.subTest(dividend=dividend,
                                  divisor=divisor,
                                  quotient=quotient,
                                  remainder=remainder,
                                  shifted_remainder=shifted_remainder):
                    obj = FixedUDivRemSqrtRSqrt(scaled_dividend,
                                                divisor,
                                                Operation.UDivRem,
                                                bit_width,
                                                fract_width,
                                                log2_radix)
                    obj.calculate()
                    self.assertEqual(obj.quotient_root, quotient)
                    self.assertEqual(obj.remainder, shifted_remainder)
1133
1134
if __name__ == "__main__":
    # Run the whole suite when this module is executed directly.
    unittest.main()