implemented FixedUDivRemSqrtRSqrt
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_algorithm.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 from nmigen.hdl.ast import Const
5 from .algorithm import (div_rem, UnsignedDivRem, DivRem,
6 Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
7 fixed_rsqrt, FixedRSqrt, Operation,
8 FixedUDivRemSqrtRSqrt)
9 import unittest
10 import math
11
12
class TestDivRemFn(unittest.TestCase):
    """Tests for the reference ``div_rem`` function on 4-bit operands."""

    def test_signed(self):
        """Check signed 4-bit ``div_rem`` against a hand-written table.

        Per the table: the quotient truncates toward zero, ``-8 / -1``
        overflows and wraps around, and division by zero yields a quotient
        of ``-1`` with the numerator as the remainder.
        """
        test_cases = [
            # numerator, denominator, quotient, remainder
            (-8, -8, 1, 0),
            (-7, -8, 0, -7),
            (-6, -8, 0, -6),
            (-5, -8, 0, -5),
            (-4, -8, 0, -4),
            (-3, -8, 0, -3),
            (-2, -8, 0, -2),
            (-1, -8, 0, -1),
            (0, -8, 0, 0),
            (1, -8, 0, 1),
            (2, -8, 0, 2),
            (3, -8, 0, 3),
            (4, -8, 0, 4),
            (5, -8, 0, 5),
            (6, -8, 0, 6),
            (7, -8, 0, 7),
            (-8, -7, 1, -1),
            (-7, -7, 1, 0),
            (-6, -7, 0, -6),
            (-5, -7, 0, -5),
            (-4, -7, 0, -4),
            (-3, -7, 0, -3),
            (-2, -7, 0, -2),
            (-1, -7, 0, -1),
            (0, -7, 0, 0),
            (1, -7, 0, 1),
            (2, -7, 0, 2),
            (3, -7, 0, 3),
            (4, -7, 0, 4),
            (5, -7, 0, 5),
            (6, -7, 0, 6),
            (7, -7, -1, 0),
            (-8, -6, 1, -2),
            (-7, -6, 1, -1),
            (-6, -6, 1, 0),
            (-5, -6, 0, -5),
            (-4, -6, 0, -4),
            (-3, -6, 0, -3),
            (-2, -6, 0, -2),
            (-1, -6, 0, -1),
            (0, -6, 0, 0),
            (1, -6, 0, 1),
            (2, -6, 0, 2),
            (3, -6, 0, 3),
            (4, -6, 0, 4),
            (5, -6, 0, 5),
            (6, -6, -1, 0),
            (7, -6, -1, 1),
            (-8, -5, 1, -3),
            (-7, -5, 1, -2),
            (-6, -5, 1, -1),
            (-5, -5, 1, 0),
            (-4, -5, 0, -4),
            (-3, -5, 0, -3),
            (-2, -5, 0, -2),
            (-1, -5, 0, -1),
            (0, -5, 0, 0),
            (1, -5, 0, 1),
            (2, -5, 0, 2),
            (3, -5, 0, 3),
            (4, -5, 0, 4),
            (5, -5, -1, 0),
            (6, -5, -1, 1),
            (7, -5, -1, 2),
            (-8, -4, 2, 0),
            (-7, -4, 1, -3),
            (-6, -4, 1, -2),
            (-5, -4, 1, -1),
            (-4, -4, 1, 0),
            (-3, -4, 0, -3),
            (-2, -4, 0, -2),
            (-1, -4, 0, -1),
            (0, -4, 0, 0),
            (1, -4, 0, 1),
            (2, -4, 0, 2),
            (3, -4, 0, 3),
            (4, -4, -1, 0),
            (5, -4, -1, 1),
            (6, -4, -1, 2),
            (7, -4, -1, 3),
            (-8, -3, 2, -2),
            (-7, -3, 2, -1),
            (-6, -3, 2, 0),
            (-5, -3, 1, -2),
            (-4, -3, 1, -1),
            (-3, -3, 1, 0),
            (-2, -3, 0, -2),
            (-1, -3, 0, -1),
            (0, -3, 0, 0),
            (1, -3, 0, 1),
            (2, -3, 0, 2),
            (3, -3, -1, 0),
            (4, -3, -1, 1),
            (5, -3, -1, 2),
            (6, -3, -2, 0),
            (7, -3, -2, 1),
            (-8, -2, 4, 0),
            (-7, -2, 3, -1),
            (-6, -2, 3, 0),
            (-5, -2, 2, -1),
            (-4, -2, 2, 0),
            (-3, -2, 1, -1),
            (-2, -2, 1, 0),
            (-1, -2, 0, -1),
            (0, -2, 0, 0),
            (1, -2, 0, 1),
            (2, -2, -1, 0),
            (3, -2, -1, 1),
            (4, -2, -2, 0),
            (5, -2, -2, 1),
            (6, -2, -3, 0),
            (7, -2, -3, 1),
            (-8, -1, -8, 0),  # overflows and wraps around
            (-7, -1, 7, 0),
            (-6, -1, 6, 0),
            (-5, -1, 5, 0),
            (-4, -1, 4, 0),
            (-3, -1, 3, 0),
            (-2, -1, 2, 0),
            (-1, -1, 1, 0),
            (0, -1, 0, 0),
            (1, -1, -1, 0),
            (2, -1, -2, 0),
            (3, -1, -3, 0),
            (4, -1, -4, 0),
            (5, -1, -5, 0),
            (6, -1, -6, 0),
            (7, -1, -7, 0),
            (-8, 0, -1, -8),
            (-7, 0, -1, -7),
            (-6, 0, -1, -6),
            (-5, 0, -1, -5),
            (-4, 0, -1, -4),
            (-3, 0, -1, -3),
            (-2, 0, -1, -2),
            (-1, 0, -1, -1),
            (0, 0, -1, 0),
            (1, 0, -1, 1),
            (2, 0, -1, 2),
            (3, 0, -1, 3),
            (4, 0, -1, 4),
            (5, 0, -1, 5),
            (6, 0, -1, 6),
            (7, 0, -1, 7),
            (-8, 1, -8, 0),
            (-7, 1, -7, 0),
            (-6, 1, -6, 0),
            (-5, 1, -5, 0),
            (-4, 1, -4, 0),
            (-3, 1, -3, 0),
            (-2, 1, -2, 0),
            (-1, 1, -1, 0),
            (0, 1, 0, 0),
            (1, 1, 1, 0),
            (2, 1, 2, 0),
            (3, 1, 3, 0),
            (4, 1, 4, 0),
            (5, 1, 5, 0),
            (6, 1, 6, 0),
            (7, 1, 7, 0),
            (-8, 2, -4, 0),
            (-7, 2, -3, -1),
            (-6, 2, -3, 0),
            (-5, 2, -2, -1),
            (-4, 2, -2, 0),
            (-3, 2, -1, -1),
            (-2, 2, -1, 0),
            (-1, 2, 0, -1),
            (0, 2, 0, 0),
            (1, 2, 0, 1),
            (2, 2, 1, 0),
            (3, 2, 1, 1),
            (4, 2, 2, 0),
            (5, 2, 2, 1),
            (6, 2, 3, 0),
            (7, 2, 3, 1),
            (-8, 3, -2, -2),
            (-7, 3, -2, -1),
            (-6, 3, -2, 0),
            (-5, 3, -1, -2),
            (-4, 3, -1, -1),
            (-3, 3, -1, 0),
            (-2, 3, 0, -2),
            (-1, 3, 0, -1),
            (0, 3, 0, 0),
            (1, 3, 0, 1),
            (2, 3, 0, 2),
            (3, 3, 1, 0),
            (4, 3, 1, 1),
            (5, 3, 1, 2),
            (6, 3, 2, 0),
            (7, 3, 2, 1),
            (-8, 4, -2, 0),
            (-7, 4, -1, -3),
            (-6, 4, -1, -2),
            (-5, 4, -1, -1),
            (-4, 4, -1, 0),
            (-3, 4, 0, -3),
            (-2, 4, 0, -2),
            (-1, 4, 0, -1),
            (0, 4, 0, 0),
            (1, 4, 0, 1),
            (2, 4, 0, 2),
            (3, 4, 0, 3),
            (4, 4, 1, 0),
            (5, 4, 1, 1),
            (6, 4, 1, 2),
            (7, 4, 1, 3),
            (-8, 5, -1, -3),
            (-7, 5, -1, -2),
            (-6, 5, -1, -1),
            (-5, 5, -1, 0),
            (-4, 5, 0, -4),
            (-3, 5, 0, -3),
            (-2, 5, 0, -2),
            (-1, 5, 0, -1),
            (0, 5, 0, 0),
            (1, 5, 0, 1),
            (2, 5, 0, 2),
            (3, 5, 0, 3),
            (4, 5, 0, 4),
            (5, 5, 1, 0),
            (6, 5, 1, 1),
            (7, 5, 1, 2),
            (-8, 6, -1, -2),
            (-7, 6, -1, -1),
            (-6, 6, -1, 0),
            (-5, 6, 0, -5),
            (-4, 6, 0, -4),
            (-3, 6, 0, -3),
            (-2, 6, 0, -2),
            (-1, 6, 0, -1),
            (0, 6, 0, 0),
            (1, 6, 0, 1),
            (2, 6, 0, 2),
            (3, 6, 0, 3),
            (4, 6, 0, 4),
            (5, 6, 0, 5),
            (6, 6, 1, 0),
            (7, 6, 1, 1),
            (-8, 7, -1, -1),
            (-7, 7, -1, 0),
            (-6, 7, 0, -6),
            (-5, 7, 0, -5),
            (-4, 7, 0, -4),
            (-3, 7, 0, -3),
            (-2, 7, 0, -2),
            (-1, 7, 0, -1),
            (0, 7, 0, 0),
            (1, 7, 0, 1),
            (2, 7, 0, 2),
            (3, 7, 0, 3),
            (4, 7, 0, 4),
            (5, 7, 0, 5),
            (6, 7, 0, 6),
            (7, 7, 1, 0),
        ]
        for (n, d, q, r) in test_cases:
            # subTest reports which table row failed and keeps checking the
            # remaining rows instead of stopping at the first mismatch.
            with self.subTest(n=n, d=d, q=q, r=r):
                self.assertEqual(div_rem(n, d, 4, True), (q, r))

    def test_unsigned(self):
        """Check unsigned 4-bit ``div_rem`` against ``//`` and ``%``.

        Division by zero is expected to give an all-ones quotient and leave
        the numerator as the remainder.
        """
        bit_width = 4
        for n in range(1 << bit_width):
            for d in range(1 << bit_width):
                if d == 0:
                    q = (1 << bit_width) - 1  # all ones
                    r = n
                else:
                    # div_rem matches // and % for unsigned integers
                    q = n // d
                    r = n % d
                with self.subTest(n=n, d=d, q=q, r=r):
                    self.assertEqual(div_rem(n, d, bit_width, False), (q, r))
288
289
class TestUnsignedDivRem(unittest.TestCase):
    """Stage-by-stage tests for the ``UnsignedDivRem`` state machine."""

    def _assert_division_identity(self, dividend, state):
        """Assert ``dividend == quotient * divisor + remainder`` for ``state``."""
        self.assertEqual(dividend, state.quotient * state.divisor
                         + state.remainder)

    def helper(self, log2_radix):
        """Run ``UnsignedDivRem`` on every pair of 4-bit operands.

        The division identity is checked before each stage and once more
        after the final stage; the final quotient and remainder must equal
        the reference ``div_rem`` result.
        """
        width = 4
        operand_values = range(1 << width)
        for dividend in operand_values:
            for divisor in operand_values:
                want_q, want_r = div_rem(dividend, divisor, width, False)
                with self.subTest(n=dividend, d=divisor, q=want_q, r=want_r):
                    state = UnsignedDivRem(dividend, divisor, width,
                                           log2_radix)
                    finished = False
                    for _ in range(250 * width):
                        self._assert_division_identity(dividend, state)
                        if state.calculate_stage():
                            finished = True
                            break
                    if not finished:
                        self.fail("infinite loop")
                    self._assert_division_identity(dividend, state)
                    self.assertEqual(state.quotient, want_q)
                    self.assertEqual(state.remainder, want_r)

    def test_radix_2(self):
        self.helper(log2_radix=1)

    def test_radix_4(self):
        self.helper(log2_radix=2)

    def test_radix_8(self):
        self.helper(log2_radix=3)

    def test_radix_16(self):
        self.helper(log2_radix=4)
321
322
class TestDivRem(unittest.TestCase):
    """Stage-by-stage tests for the signed/unsigned ``DivRem`` state machine."""

    def helper(self, log2_radix):
        """Run ``DivRem`` on every 4-bit bit-pattern pair, signed and unsigned.

        Each raw bit pattern is reinterpreted with ``Const.normalize`` for the
        signedness under test, and the final quotient/remainder must match
        the reference ``div_rem`` result.
        """
        bit_width = 4
        for n_bits in range(1 << bit_width):
            for d_bits in range(1 << bit_width):
                for signed in False, True:
                    # Normalize into fresh names: rebinding the loop
                    # variables would make later iterations depend on
                    # normalize() round-tripping the previous signed value.
                    n = Const.normalize(n_bits, (bit_width, signed))
                    d = Const.normalize(d_bits, (bit_width, signed))
                    q, r = div_rem(n, d, bit_width, signed)
                    with self.subTest(n=n, d=d, q=q, r=r, signed=signed,
                                      log2_radix=log2_radix):
                        dr = DivRem(n, d, bit_width, signed, log2_radix)
                        for _ in range(250 * bit_width):
                            if dr.calculate_stage():
                                break
                        else:
                            self.fail("infinite loop")
                        self.assertEqual(dr.quotient, q)
                        self.assertEqual(dr.remainder, r)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
353
354
class TestFixed(unittest.TestCase):
    """Tests for the ``Fixed`` fixed-point reference type."""

    def test_constructor(self):
        """Construct ``Fixed`` from ints, floats and other ``Fixed`` values."""
        value = Fixed(0, 0, 1, False)
        self.assertEqual(value.bits, 0)
        self.assertEqual(value.fract_width, 0)
        self.assertEqual(value.bit_width, 1)
        self.assertEqual(value.signed, False)
        # 1.0 with 2 fraction bits is 0b100, which wraps to -4 in 3 signed bits
        value = Fixed(1, 2, 3, True)
        self.assertEqual(value.bits, -4)
        self.assertEqual(value.fract_width, 2)
        self.assertEqual(value.bit_width, 3)
        self.assertEqual(value.signed, True)
        value = Fixed(1, 2, 4, True)
        self.assertEqual(value.bits, 4)
        self.assertEqual(value.fract_width, 2)
        self.assertEqual(value.bit_width, 4)
        self.assertEqual(value.signed, True)
        value = Fixed(1.25, 4, 8, True)
        self.assertEqual(value.bits, 0x14)
        self.assertEqual(value.fract_width, 4)
        self.assertEqual(value.bit_width, 8)
        self.assertEqual(value.signed, True)
        value = Fixed(Fixed(2, 0, 12, False), 4, 8, True)
        self.assertEqual(value.bits, 0x20)
        self.assertEqual(value.fract_width, 4)
        self.assertEqual(value.bit_width, 8)
        self.assertEqual(value.signed, True)
        value = Fixed(0x2FF / 2 ** 8, 8, 12, False)
        self.assertEqual(value.bits, 0x2FF)
        self.assertEqual(value.fract_width, 8)
        self.assertEqual(value.bit_width, 12)
        self.assertEqual(value.signed, False)
        # converting to fewer fraction bits drops the low fraction bits
        value = Fixed(value, 4, 8, True)
        self.assertEqual(value.bits, 0x2F)
        self.assertEqual(value.fract_width, 4)
        self.assertEqual(value.bit_width, 8)
        self.assertEqual(value.signed, True)

    def helper_tst_from_bits(self, bit_width, fract_width):
        """Check ``Fixed.from_bits`` round-trips every bit pattern of a format."""
        signed = False
        for bits in range(1 << bit_width):
            with self.subTest(bit_width=bit_width,
                              fract_width=fract_width,
                              signed=signed,
                              bits=hex(bits)):
                value = Fixed.from_bits(bits, fract_width, bit_width, signed)
                self.assertEqual(value.bit_width, bit_width)
                self.assertEqual(value.fract_width, fract_width)
                self.assertEqual(value.signed, signed)
                self.assertEqual(value.bits, bits)
        signed = True
        for bits in range(-1 << (bit_width - 1), 1 << (bit_width - 1)):
            with self.subTest(bit_width=bit_width,
                              fract_width=fract_width,
                              signed=signed,
                              bits=hex(bits)):
                value = Fixed.from_bits(bits, fract_width, bit_width, signed)
                self.assertEqual(value.bit_width, bit_width)
                self.assertEqual(value.fract_width, fract_width)
                self.assertEqual(value.signed, signed)
                self.assertEqual(value.bits, bits)

    def test_from_bits(self):
        for bit_width in range(1, 5):
            for fract_width in range(bit_width):
                self.helper_tst_from_bits(bit_width, fract_width)

    def test_repr(self):
        self.assertEqual(repr(Fixed.from_bits(1, 2, 3, False)),
                         "Fixed.from_bits(1, 2, 3, False)")
        self.assertEqual(repr(Fixed.from_bits(-4, 2, 3, True)),
                         "Fixed.from_bits(-4, 2, 3, True)")
        self.assertEqual(repr(Fixed.from_bits(-4, 7, 10, True)),
                         "Fixed.from_bits(-4, 7, 10, True)")

    def test_trunc(self):
        for i in range(-8, 8):
            value = Fixed.from_bits(i, 2, 4, True)
            with self.subTest(value=repr(value)):
                self.assertEqual(math.trunc(value), math.trunc(i / 4))

    def test_int(self):
        """``int()`` rounds toward zero, like ``math.trunc`` on the float value."""
        for i in range(-8, 8):
            value = Fixed.from_bits(i, 2, 4, True)
            with self.subTest(value=repr(value)):
                # Compare against an independent float oracle rather than
                # Fixed.__trunc__, which is itself under test.
                self.assertEqual(int(value), math.trunc(i / 4))

    def test_float(self):
        for i in range(-8, 8):
            value = Fixed.from_bits(i, 2, 4, True)
            with self.subTest(value=repr(value)):
                self.assertEqual(float(value), i / 4)

    def test_floor(self):
        for i in range(-8, 8):
            value = Fixed.from_bits(i, 2, 4, True)
            with self.subTest(value=repr(value)):
                self.assertEqual(math.floor(value), math.floor(i / 4))

    def test_ceil(self):
        for i in range(-8, 8):
            value = Fixed.from_bits(i, 2, 4, True)
            with self.subTest(value=repr(value)):
                self.assertEqual(math.ceil(value), math.ceil(i / 4))

    def test_neg(self):
        for i in range(-8, 8):
            value = Fixed.from_bits(i, 2, 4, True)
            expected = -i / 4 if i != -8 else -2.0  # handle wrap-around
            with self.subTest(value=repr(value)):
                self.assertEqual(float(-value), expected)

    def test_pos(self):
        for i in range(-8, 8):
            value = Fixed.from_bits(i, 2, 4, True)
            with self.subTest(value=repr(value)):
                value = +value
                self.assertEqual(value.bits, i)

    def test_abs(self):
        for i in range(-8, 8):
            value = Fixed.from_bits(i, 2, 4, True)
            expected = abs(i) / 4 if i != -8 else -2.0  # handle wrap-around
            with self.subTest(value=repr(value)):
                self.assertEqual(float(abs(value)), expected)

    def test_not(self):
        for i in range(-8, 8):
            value = Fixed.from_bits(i, 2, 4, True)
            with self.subTest(value=repr(value)):
                self.assertEqual(float(~value), (~i) / 4)

    @staticmethod
    def get_test_values(max_bit_width, include_int):
        """Yield operands for the binary-operator tests.

        For both signednesses this yields, when ``include_int`` is true,
        every ``max_bit_width``-bit integer, then every ``Fixed`` with
        ``bit_width`` in ``[1, max_bit_width)`` and ``fract_width`` in
        ``[0, bit_width]``.
        """
        for signed in False, True:
            if include_int:
                for bits in range(1 << max_bit_width):
                    int_value = Const.normalize(bits, (max_bit_width, signed))
                    yield int_value
            for bit_width in range(1, max_bit_width):
                for fract_width in range(bit_width + 1):
                    for bits in range(1 << bit_width):
                        yield Fixed.from_bits(bits,
                                              fract_width,
                                              bit_width,
                                              signed)

    def binary_op_test_helper(self,
                              operation,
                              is_fixed=True,
                              width_combine_op=max,
                              adjust_bits_op=None):
        """Check ``operation`` on ``Fixed`` operands against integer math.

        :param operation: binary callable applied both to the ``Fixed``
            operands and to their aligned raw ``bits``.
        :param is_fixed: true if the result is a ``Fixed``; false if it is a
            plain value (e.g. comparisons).
        :param width_combine_op: combines the operands' fraction and integer
            widths into the expected result widths.
        :param adjust_bits_op: ``(bits, out_fract_width, in_fract_width)``
            callable that aligns raw bits; defaults to a left shift.
        """
        def default_adjust_bits_op(bits, out_fract_width, in_fract_width):
            return bits << (out_fract_width - in_fract_width)
        if adjust_bits_op is None:
            adjust_bits_op = default_adjust_bits_op
        max_bit_width = 5
        for lhs in self.get_test_values(max_bit_width, True):
            lhs_is_int = isinstance(lhs, int)
            for rhs in self.get_test_values(max_bit_width, not lhs_is_int):
                rhs_is_int = isinstance(rhs, int)
                if lhs_is_int:
                    assert not rhs_is_int
                    lhs_int = adjust_bits_op(lhs, rhs.fract_width, 0)
                    int_result = operation(lhs_int, rhs.bits)
                    if is_fixed:
                        expected = Fixed.from_bits(int_result,
                                                   rhs.fract_width,
                                                   rhs.bit_width,
                                                   rhs.signed)
                    else:
                        expected = int_result
                elif rhs_is_int:
                    rhs_int = adjust_bits_op(rhs, lhs.fract_width, 0)
                    int_result = operation(lhs.bits, rhs_int)
                    if is_fixed:
                        expected = Fixed.from_bits(int_result,
                                                   lhs.fract_width,
                                                   lhs.bit_width,
                                                   lhs.signed)
                    else:
                        expected = int_result
                elif lhs.signed != rhs.signed:
                    # mixed-signedness Fixed operands are not tested
                    continue
                else:
                    fract_width = width_combine_op(lhs.fract_width,
                                                   rhs.fract_width)
                    int_width = width_combine_op(lhs.bit_width
                                                 - lhs.fract_width,
                                                 rhs.bit_width
                                                 - rhs.fract_width)
                    bit_width = fract_width + int_width
                    lhs_int = adjust_bits_op(lhs.bits,
                                             fract_width,
                                             lhs.fract_width)
                    rhs_int = adjust_bits_op(rhs.bits,
                                             fract_width,
                                             rhs.fract_width)
                    int_result = operation(lhs_int, rhs_int)
                    if is_fixed:
                        expected = Fixed.from_bits(int_result,
                                                   fract_width,
                                                   bit_width,
                                                   lhs.signed)
                    else:
                        expected = int_result
                with self.subTest(lhs=repr(lhs),
                                  rhs=repr(rhs),
                                  expected=repr(expected)):
                    result = operation(lhs, rhs)
                    if is_fixed:
                        self.assertEqual(result.bit_width, expected.bit_width)
                        self.assertEqual(result.signed, expected.signed)
                        self.assertEqual(result.fract_width,
                                         expected.fract_width)
                        self.assertEqual(result.bits, expected.bits)
                    else:
                        self.assertEqual(result, expected)

    def test_add(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs + rhs)

    def test_sub(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs - rhs)

    def test_and(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs & rhs)

    def test_or(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs | rhs)

    def test_xor(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs ^ rhs)

    def test_mul(self):
        # multiplication needs no alignment: fraction widths simply add
        def adjust_bits_op(bits, out_fract_width, in_fract_width):
            return bits
        self.binary_op_test_helper(lambda lhs, rhs: lhs * rhs,
                                   True,
                                   lambda l_width, r_width: l_width + r_width,
                                   adjust_bits_op)

    def test_cmp(self):
        def cmp(lhs, rhs):
            if lhs < rhs:
                return -1
            elif lhs > rhs:
                return 1
            return 0
        self.binary_op_test_helper(cmp, False)

    def test_lt(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs < rhs, False)

    def test_le(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs <= rhs, False)

    def test_eq(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs == rhs, False)

    def test_ne(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs != rhs, False)

    def test_gt(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs > rhs, False)

    def test_ge(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs >= rhs, False)

    def test_bool(self):
        for v in self.get_test_values(6, False):
            with self.subTest(v=repr(v)):
                self.assertEqual(bool(v), bool(v.bits))

    def test_str(self):
        self.assertEqual(str(Fixed.from_bits(0x1234, 0, 16, False)),
                         "fixed:0x1234.")
        self.assertEqual(str(Fixed.from_bits(-0x1234, 0, 16, True)),
                         "fixed:-0x1234.")
        self.assertEqual(str(Fixed.from_bits(0x12345, 3, 20, True)),
                         "fixed:0x2468.a")
        self.assertEqual(str(Fixed(123.625, 3, 12, True)),
                         "fixed:0x7b.a")

        self.assertEqual(str(Fixed.from_bits(0x1, 0, 20, True)),
                         "fixed:0x1.")
        self.assertEqual(str(Fixed.from_bits(0x2, 1, 20, True)),
                         "fixed:0x1.0")
        self.assertEqual(str(Fixed.from_bits(0x4, 2, 20, True)),
                         "fixed:0x1.0")
        self.assertEqual(str(Fixed.from_bits(0x9, 3, 20, True)),
                         "fixed:0x1.2")
        self.assertEqual(str(Fixed.from_bits(0x12, 4, 20, True)),
                         "fixed:0x1.2")
        self.assertEqual(str(Fixed.from_bits(0x24, 5, 20, True)),
                         "fixed:0x1.20")
        self.assertEqual(str(Fixed.from_bits(0x48, 6, 20, True)),
                         "fixed:0x1.20")
        self.assertEqual(str(Fixed.from_bits(0x91, 7, 20, True)),
                         "fixed:0x1.22")
        self.assertEqual(str(Fixed.from_bits(0x123, 8, 20, True)),
                         "fixed:0x1.23")
        self.assertEqual(str(Fixed.from_bits(0x246, 9, 20, True)),
                         "fixed:0x1.230")
        self.assertEqual(str(Fixed.from_bits(0x48d, 10, 20, True)),
                         "fixed:0x1.234")
        self.assertEqual(str(Fixed.from_bits(0x91a, 11, 20, True)),
                         "fixed:0x1.234")
        self.assertEqual(str(Fixed.from_bits(0x1234, 12, 20, True)),
                         "fixed:0x1.234")
        self.assertEqual(str(Fixed.from_bits(0x2468, 13, 20, True)),
                         "fixed:0x1.2340")
        self.assertEqual(str(Fixed.from_bits(0x48d1, 14, 20, True)),
                         "fixed:0x1.2344")
        self.assertEqual(str(Fixed.from_bits(0x91a2, 15, 20, True)),
                         "fixed:0x1.2344")
        self.assertEqual(str(Fixed.from_bits(0x12345, 16, 20, True)),
                         "fixed:0x1.2345")
        self.assertEqual(str(Fixed.from_bits(0x2468a, 17, 20, True)),
                         "fixed:0x1.23450")
        self.assertEqual(str(Fixed.from_bits(0x48d14, 18, 20, True)),
                         "fixed:0x1.23450")
        self.assertEqual(str(Fixed.from_bits(0x91a28, 19, 20, True)),
                         "fixed:-0x0.dcbb0")
        self.assertEqual(str(Fixed.from_bits(0x91a28, 19, 20, False)),
                         "fixed:0x1.23450")
681
682
class TestFixedSqrtFn(unittest.TestCase):
    """Tests for the reference ``fixed_sqrt`` function."""

    def test_on_ints(self):
        """Plain ints yield the floor root and ``radicand - root**2``."""
        def check(radicand, expected):
            with self.subTest(radicand=radicand, expected=expected):
                self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))

        def root_remainder_for(radicand, root):
            return RootRemainder(root, radicand - root * root)

        for radicand in range(-1, 32):
            # negative radicands are expected to produce None
            if radicand < 0:
                expected = None
            else:
                floor_root = math.floor(math.sqrt(radicand))
                expected = root_remainder_for(radicand, floor_root)
            check(radicand, expected)

        big_radicand = 2 << 64
        check(big_radicand, root_remainder_for(big_radicand, 0x16A09E667))

    def test_on_fixed(self):
        """Non-negative ``Fixed`` radicands match a float-based reference."""
        def non_negative_radicands():
            for is_signed in False, True:
                for width in range(1, 10):
                    for fract in range(width):
                        for raw in range(1 << width):
                            candidate = Fixed.from_bits(raw, fract, width,
                                                        is_signed)
                            if candidate < 0:
                                continue
                            yield candidate

        for radicand in non_negative_radicands():
            approx_root = radicand.with_value(math.sqrt(float(radicand)))
            expected = RootRemainder(approx_root,
                                     radicand - approx_root * approx_root)
            with self.subTest(radicand=repr(radicand),
                              expected=repr(expected)):
                self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))

    def test_misc_cases(self):
        """Hand-checked large cases, compared via ``str``."""
        cases = (
            # radicand, expected
            (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
            (Fixed(2, 30, 32, False),
             "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)"),
        )
        for radicand, expected in cases:
            with self.subTest(radicand=str(radicand), expected=expected):
                self.assertEqual(str(fixed_sqrt(radicand)), expected)
730
731
class TestFixedSqrt(unittest.TestCase):
    """Stage-by-stage tests for the ``FixedSqrt`` state machine."""

    def _assert_stage_invariants(self, state):
        """``root_squared`` tracks ``root**2`` and never exceeds the radicand."""
        self.assertEqual(state.root * state.root, state.root_squared)
        self.assertGreaterEqual(state.radicand, state.root_squared)

    def helper(self, log2_radix):
        """Compare ``FixedSqrt`` with ``fixed_sqrt`` for small unsigned formats."""
        for width in range(1, 8):
            for fract in range(width):
                for raw in range(1 << width):
                    radicand = Fixed.from_bits(raw, fract, width, False)
                    reference = fixed_sqrt(radicand)
                    with self.subTest(radicand=repr(radicand),
                                      root_remainder=repr(reference),
                                      log2_radix=log2_radix):
                        state = FixedSqrt(radicand, log2_radix)
                        finished = False
                        for _ in range(250 * width):
                            self._assert_stage_invariants(state)
                            if state.calculate_stage():
                                finished = True
                                break
                        if not finished:
                            self.fail("infinite loop")
                        self._assert_stage_invariants(state)
                        self.assertEqual(state.remainder,
                                         state.radicand - state.root_squared)
                        self.assertEqual(state.root, reference.root)
                        self.assertEqual(state.remainder, reference.remainder)

    def test_radix_2(self):
        self.helper(log2_radix=1)

    def test_radix_4(self):
        self.helper(log2_radix=2)

    def test_radix_8(self):
        self.helper(log2_radix=3)

    def test_radix_16(self):
        self.helper(log2_radix=4)
776
777
class TestFixedRSqrtFn(unittest.TestCase):
    """Tests for the reference ``fixed_rsqrt`` function."""

    def _assert_rsqrt(self, radicand, root):
        """Assert ``fixed_rsqrt`` yields ``root`` and ``1 - root**2 * radicand``."""
        expected = RootRemainder(root, 1 - root * root * radicand)
        with self.subTest(radicand=repr(radicand), expected=repr(expected)):
            self.assertEqual(repr(fixed_rsqrt(radicand)), repr(expected))

    def test2(self):
        """Positive radicands in an unsigned 12-bit, 5-fraction-bit format."""
        for raw in range(1, 1 << 5):
            radicand = Fixed.from_bits(raw, 5, 12, False)
            reciprocal = 1 / math.sqrt(float(radicand))
            self._assert_rsqrt(radicand, radicand.with_value(reciprocal))

    def test(self):
        """Every positive radicand of small formats, with the root clamped."""
        for is_signed in False, True:
            for width in range(1, 10):
                for fract in range(width):
                    for raw in range(1 << width):
                        radicand = Fixed.from_bits(raw, fract, width,
                                                   is_signed)
                        if radicand <= 0:
                            continue
                        reciprocal = 1 / math.sqrt(float(radicand))
                        # largest representable value in the radicand format
                        largest = radicand.with_bits(
                            (1 << (width - is_signed)) - 1)
                        root = (largest
                                if reciprocal > float(largest)
                                else radicand.with_value(reciprocal))
                        self._assert_rsqrt(radicand, root)

    def test_misc_cases(self):
        """Hand-checked large case, compared via ``str``."""
        cases = (
            # radicand, expected
            (Fixed(0.5, 30, 32, False),
             "RootRemainder(fixed:0x1.6a09e664, "
             "fixed:0x0.0000000596d014780000000)"),
        )
        for radicand, expected in cases:
            with self.subTest(radicand=str(radicand), expected=expected):
                self.assertEqual(str(fixed_rsqrt(radicand)), expected)
826
827
class TestFixedRSqrt(unittest.TestCase):
    """Stage-by-stage tests for the ``FixedRSqrt`` state machine."""

    def _assert_stage_invariants(self, state):
        """Cached products stay consistent and ``radicand * root**2 <= 1``."""
        self.assertEqual(state.radicand * state.root, state.radicand_root)
        self.assertEqual(state.radicand_root * state.root,
                         state.radicand_root_squared)
        self.assertGreaterEqual(1, state.radicand_root_squared)

    def helper(self, log2_radix):
        """Compare ``FixedRSqrt`` with ``fixed_rsqrt`` for positive radicands."""
        for width in range(1, 8):
            for fract in range(width):
                for raw in range(1, 1 << width):
                    radicand = Fixed.from_bits(raw, fract, width, False)
                    reference = fixed_rsqrt(radicand)
                    with self.subTest(radicand=repr(radicand),
                                      root_remainder=repr(reference),
                                      log2_radix=log2_radix):
                        state = FixedRSqrt(radicand, log2_radix)
                        finished = False
                        for _ in range(250 * width):
                            self._assert_stage_invariants(state)
                            if state.calculate_stage():
                                finished = True
                                break
                        if not finished:
                            self.fail("infinite loop")
                        self._assert_stage_invariants(state)
                        self.assertEqual(state.remainder,
                                         1 - state.radicand_root_squared)
                        self.assertEqual(state.root, reference.root)
                        self.assertEqual(state.remainder, reference.remainder)

    def test_radix_2(self):
        self.helper(log2_radix=1)

    def test_radix_4(self):
        self.helper(log2_radix=2)

    def test_radix_8(self):
        self.helper(log2_radix=3)

    def test_radix_16(self):
        self.helper(log2_radix=4)
876
877
class TestFixedUDivRemSqrtRSqrt(unittest.TestCase):
    """Stage-by-stage tests for the combined ``FixedUDivRemSqrtRSqrt`` unit.

    Each operation is checked against its reference function: ``div_rem``,
    ``fixed_sqrt`` or ``fixed_rsqrt``.
    """

    @staticmethod
    def show_fixed(bits, fract_width, bit_width):
        """Format unsigned ``bits`` as ``"<str>:<repr>"`` of a ``Fixed``."""
        fixed = Fixed.from_bits(bits, fract_width, bit_width, False)
        return f"{fixed!s}:{fixed!r}"

    def check_invariants(self,
                         dividend,
                         divisor_radicand,
                         operation,
                         bit_width,
                         fract_width,
                         log2_radix,
                         obj):
        """Assert the invariants ``obj`` must satisfy before/after every stage.

        The constructor parameters must be unchanged, ``root_times_radicand``
        must track ``quotient_root * divisor_radicand``, and
        ``remainder == compare_lhs - compare_rhs >= 0`` with operation-specific
        definitions of the two compared quantities.
        """
        preserved_fields = (
            ("dividend", dividend),
            ("divisor_radicand", divisor_radicand),
            ("operation", operation),
            ("bit_width", bit_width),
            ("fract_width", fract_width),
            ("log2_radix", log2_radix),
        )
        for field_name, field_value in preserved_fields:
            self.assertEqual(getattr(obj, field_name), field_value)
        self.assertEqual(obj.root_times_radicand,
                         obj.quotient_root * obj.divisor_radicand)
        self.assertGreaterEqual(obj.compare_lhs, obj.compare_rhs)
        self.assertEqual(obj.remainder, obj.compare_lhs - obj.compare_rhs)
        if operation is Operation.UDivRem:
            want_lhs = obj.dividend << fract_width
            want_rhs = ((obj.quotient_root * obj.divisor_radicand)
                        << fract_width)
        elif operation is Operation.SqrtRem:
            want_lhs = obj.divisor_radicand << (fract_width * 2)
            want_rhs = ((obj.quotient_root * obj.quotient_root)
                        << fract_width)
        else:
            assert operation is Operation.RSqrtRem
            want_lhs = 1 << (fract_width * 3)
            want_rhs = (obj.quotient_root * obj.quotient_root
                        * obj.divisor_radicand)
        self.assertEqual(obj.compare_lhs, want_lhs)
        self.assertEqual(obj.compare_rhs, want_rhs)

    def _assert_reference_widths(self, root_remainder, bit_width,
                                 fract_width, remainder_scale):
        """Check the formats of a reference ``RootRemainder`` result.

        The root must have the input format; the remainder must have
        ``remainder_scale`` times the input widths.
        """
        self.assertEqual(root_remainder.root.bit_width, bit_width)
        self.assertEqual(root_remainder.root.fract_width, fract_width)
        self.assertEqual(root_remainder.remainder.bit_width,
                         bit_width * remainder_scale)
        self.assertEqual(root_remainder.remainder.fract_width,
                         fract_width * remainder_scale)

    def _reference_result(self, dividend, divisor_radicand, operation,
                          bit_width, fract_width):
        """Return reference ``(quotient_root, remainder)`` bits, or ``None``.

        ``None`` means the case is skipped (division or reciprocal square
        root of zero).
        """
        if operation is Operation.UDivRem:
            if divisor_radicand == 0:
                return None
            quotient, rem = div_rem(dividend, divisor_radicand,
                                    bit_width * 3, False)
            return quotient, rem << fract_width
        if operation is Operation.SqrtRem:
            reference = fixed_sqrt(Fixed.from_bits(divisor_radicand,
                                                   fract_width,
                                                   bit_width,
                                                   False))
            self._assert_reference_widths(reference, bit_width,
                                          fract_width, 2)
            return (reference.root.bits,
                    reference.remainder.bits << fract_width)
        assert operation is Operation.RSqrtRem
        if divisor_radicand == 0:
            return None
        reference = fixed_rsqrt(Fixed.from_bits(divisor_radicand,
                                                fract_width,
                                                bit_width,
                                                False))
        self._assert_reference_widths(reference, bit_width, fract_width, 3)
        return reference.root.bits, reference.remainder.bits

    def handle_case(self,
                    dividend,
                    divisor_radicand,
                    operation,
                    bit_width,
                    fract_width,
                    log2_radix):
        """Run one operand combination to completion and check the result."""
        dividend_str = self.show_fixed(dividend,
                                       fract_width * 2,
                                       bit_width + fract_width)
        divisor_radicand_str = self.show_fixed(divisor_radicand,
                                               fract_width,
                                               bit_width)
        with self.subTest(dividend=dividend_str,
                          divisor_radicand=divisor_radicand_str,
                          operation=operation.name,
                          bit_width=bit_width,
                          fract_width=fract_width,
                          log2_radix=log2_radix):
            reference = self._reference_result(dividend, divisor_radicand,
                                               operation, bit_width,
                                               fract_width)
            if reference is None:
                return
            quotient_root, remainder = reference
            # results that do not fit in bit_width bits are not tested
            if quotient_root >= (1 << bit_width):
                return
            quotient_root_str = self.show_fixed(quotient_root,
                                                fract_width,
                                                bit_width)
            remainder_str = self.show_fixed(remainder,
                                            fract_width * 3,
                                            bit_width * 3)
            with self.subTest(quotient_root=quotient_root_str,
                              remainder=remainder_str):
                params = (dividend, divisor_radicand, operation,
                          bit_width, fract_width, log2_radix)
                obj = FixedUDivRemSqrtRSqrt(*params)
                finished = False
                for _ in range(250 * bit_width):
                    self.check_invariants(*params, obj)
                    if obj.calculate_stage():
                        finished = True
                        break
                if not finished:
                    self.fail("infinite loop")
                self.check_invariants(*params, obj)
                self.assertEqual(obj.quotient_root, quotient_root)
                self.assertEqual(obj.remainder, remainder)

    def helper(self, log2_radix, operation):
        """Exhaustively test ``operation`` over small fixed-point formats."""
        is_division = operation is Operation.UDivRem
        # division also sweeps every dividend, so it uses narrower widths
        # (presumably to bound run time)
        bit_width_range = range(1, 6) if is_division else range(1, 8)
        for bit_width in bit_width_range:
            for fract_width in range(bit_width):
                if is_division:
                    dividend_range = range(1 << (bit_width + fract_width))
                else:
                    dividend_range = range(1)
                for divisor_radicand in range(1 << bit_width):
                    for dividend in dividend_range:
                        self.handle_case(dividend,
                                         divisor_radicand,
                                         operation,
                                         bit_width,
                                         fract_width,
                                         log2_radix)

    def test_radix_2_UDiv(self):
        self.helper(1, Operation.UDivRem)

    def test_radix_4_UDiv(self):
        self.helper(2, Operation.UDivRem)

    def test_radix_8_UDiv(self):
        self.helper(3, Operation.UDivRem)

    def test_radix_16_UDiv(self):
        self.helper(4, Operation.UDivRem)

    def test_radix_2_Sqrt(self):
        self.helper(1, Operation.SqrtRem)

    def test_radix_4_Sqrt(self):
        self.helper(2, Operation.SqrtRem)

    def test_radix_8_Sqrt(self):
        self.helper(3, Operation.SqrtRem)

    def test_radix_16_Sqrt(self):
        self.helper(4, Operation.SqrtRem)

    def test_radix_2_RSqrt(self):
        self.helper(1, Operation.RSqrtRem)

    def test_radix_4_RSqrt(self):
        self.helper(2, Operation.RSqrtRem)

    def test_radix_8_RSqrt(self):
        self.helper(3, Operation.RSqrtRem)

    def test_radix_16_RSqrt(self):
        self.helper(4, Operation.RSqrtRem)