implement FixedRSqrt
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_algorithm.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 from nmigen.hdl.ast import Const
5 from .algorithm import (div_rem, UnsignedDivRem, DivRem,
6 Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
7 fixed_rsqrt, FixedRSqrt)
8 import unittest
9 import math
10
11
class TestDivRemFn(unittest.TestCase):
    """Tests for the reference ``div_rem`` function on 4-bit operands."""

    # Operand width, in bits, shared by every test in this class.
    BIT_WIDTH = 4

    def test_signed(self):
        """Check signed division against a hand-written truth table.

        Every (numerator, denominator) pair of 4-bit two's-complement values
        is listed, including division by zero and the ``-8 / -1`` case that
        overflows and wraps around.
        """
        test_cases = [
            # numerator, denominator, quotient, remainder
            (-8, -8, 1, 0),
            (-7, -8, 0, -7),
            (-6, -8, 0, -6),
            (-5, -8, 0, -5),
            (-4, -8, 0, -4),
            (-3, -8, 0, -3),
            (-2, -8, 0, -2),
            (-1, -8, 0, -1),
            (0, -8, 0, 0),
            (1, -8, 0, 1),
            (2, -8, 0, 2),
            (3, -8, 0, 3),
            (4, -8, 0, 4),
            (5, -8, 0, 5),
            (6, -8, 0, 6),
            (7, -8, 0, 7),
            (-8, -7, 1, -1),
            (-7, -7, 1, 0),
            (-6, -7, 0, -6),
            (-5, -7, 0, -5),
            (-4, -7, 0, -4),
            (-3, -7, 0, -3),
            (-2, -7, 0, -2),
            (-1, -7, 0, -1),
            (0, -7, 0, 0),
            (1, -7, 0, 1),
            (2, -7, 0, 2),
            (3, -7, 0, 3),
            (4, -7, 0, 4),
            (5, -7, 0, 5),
            (6, -7, 0, 6),
            (7, -7, -1, 0),
            (-8, -6, 1, -2),
            (-7, -6, 1, -1),
            (-6, -6, 1, 0),
            (-5, -6, 0, -5),
            (-4, -6, 0, -4),
            (-3, -6, 0, -3),
            (-2, -6, 0, -2),
            (-1, -6, 0, -1),
            (0, -6, 0, 0),
            (1, -6, 0, 1),
            (2, -6, 0, 2),
            (3, -6, 0, 3),
            (4, -6, 0, 4),
            (5, -6, 0, 5),
            (6, -6, -1, 0),
            (7, -6, -1, 1),
            (-8, -5, 1, -3),
            (-7, -5, 1, -2),
            (-6, -5, 1, -1),
            (-5, -5, 1, 0),
            (-4, -5, 0, -4),
            (-3, -5, 0, -3),
            (-2, -5, 0, -2),
            (-1, -5, 0, -1),
            (0, -5, 0, 0),
            (1, -5, 0, 1),
            (2, -5, 0, 2),
            (3, -5, 0, 3),
            (4, -5, 0, 4),
            (5, -5, -1, 0),
            (6, -5, -1, 1),
            (7, -5, -1, 2),
            (-8, -4, 2, 0),
            (-7, -4, 1, -3),
            (-6, -4, 1, -2),
            (-5, -4, 1, -1),
            (-4, -4, 1, 0),
            (-3, -4, 0, -3),
            (-2, -4, 0, -2),
            (-1, -4, 0, -1),
            (0, -4, 0, 0),
            (1, -4, 0, 1),
            (2, -4, 0, 2),
            (3, -4, 0, 3),
            (4, -4, -1, 0),
            (5, -4, -1, 1),
            (6, -4, -1, 2),
            (7, -4, -1, 3),
            (-8, -3, 2, -2),
            (-7, -3, 2, -1),
            (-6, -3, 2, 0),
            (-5, -3, 1, -2),
            (-4, -3, 1, -1),
            (-3, -3, 1, 0),
            (-2, -3, 0, -2),
            (-1, -3, 0, -1),
            (0, -3, 0, 0),
            (1, -3, 0, 1),
            (2, -3, 0, 2),
            (3, -3, -1, 0),
            (4, -3, -1, 1),
            (5, -3, -1, 2),
            (6, -3, -2, 0),
            (7, -3, -2, 1),
            (-8, -2, 4, 0),
            (-7, -2, 3, -1),
            (-6, -2, 3, 0),
            (-5, -2, 2, -1),
            (-4, -2, 2, 0),
            (-3, -2, 1, -1),
            (-2, -2, 1, 0),
            (-1, -2, 0, -1),
            (0, -2, 0, 0),
            (1, -2, 0, 1),
            (2, -2, -1, 0),
            (3, -2, -1, 1),
            (4, -2, -2, 0),
            (5, -2, -2, 1),
            (6, -2, -3, 0),
            (7, -2, -3, 1),
            (-8, -1, -8, 0),  # overflows and wraps around
            (-7, -1, 7, 0),
            (-6, -1, 6, 0),
            (-5, -1, 5, 0),
            (-4, -1, 4, 0),
            (-3, -1, 3, 0),
            (-2, -1, 2, 0),
            (-1, -1, 1, 0),
            (0, -1, 0, 0),
            (1, -1, -1, 0),
            (2, -1, -2, 0),
            (3, -1, -3, 0),
            (4, -1, -4, 0),
            (5, -1, -5, 0),
            (6, -1, -6, 0),
            (7, -1, -7, 0),
            (-8, 0, -1, -8),
            (-7, 0, -1, -7),
            (-6, 0, -1, -6),
            (-5, 0, -1, -5),
            (-4, 0, -1, -4),
            (-3, 0, -1, -3),
            (-2, 0, -1, -2),
            (-1, 0, -1, -1),
            (0, 0, -1, 0),
            (1, 0, -1, 1),
            (2, 0, -1, 2),
            (3, 0, -1, 3),
            (4, 0, -1, 4),
            (5, 0, -1, 5),
            (6, 0, -1, 6),
            (7, 0, -1, 7),
            (-8, 1, -8, 0),
            (-7, 1, -7, 0),
            (-6, 1, -6, 0),
            (-5, 1, -5, 0),
            (-4, 1, -4, 0),
            (-3, 1, -3, 0),
            (-2, 1, -2, 0),
            (-1, 1, -1, 0),
            (0, 1, 0, 0),
            (1, 1, 1, 0),
            (2, 1, 2, 0),
            (3, 1, 3, 0),
            (4, 1, 4, 0),
            (5, 1, 5, 0),
            (6, 1, 6, 0),
            (7, 1, 7, 0),
            (-8, 2, -4, 0),
            (-7, 2, -3, -1),
            (-6, 2, -3, 0),
            (-5, 2, -2, -1),
            (-4, 2, -2, 0),
            (-3, 2, -1, -1),
            (-2, 2, -1, 0),
            (-1, 2, 0, -1),
            (0, 2, 0, 0),
            (1, 2, 0, 1),
            (2, 2, 1, 0),
            (3, 2, 1, 1),
            (4, 2, 2, 0),
            (5, 2, 2, 1),
            (6, 2, 3, 0),
            (7, 2, 3, 1),
            (-8, 3, -2, -2),
            (-7, 3, -2, -1),
            (-6, 3, -2, 0),
            (-5, 3, -1, -2),
            (-4, 3, -1, -1),
            (-3, 3, -1, 0),
            (-2, 3, 0, -2),
            (-1, 3, 0, -1),
            (0, 3, 0, 0),
            (1, 3, 0, 1),
            (2, 3, 0, 2),
            (3, 3, 1, 0),
            (4, 3, 1, 1),
            (5, 3, 1, 2),
            (6, 3, 2, 0),
            (7, 3, 2, 1),
            (-8, 4, -2, 0),
            (-7, 4, -1, -3),
            (-6, 4, -1, -2),
            (-5, 4, -1, -1),
            (-4, 4, -1, 0),
            (-3, 4, 0, -3),
            (-2, 4, 0, -2),
            (-1, 4, 0, -1),
            (0, 4, 0, 0),
            (1, 4, 0, 1),
            (2, 4, 0, 2),
            (3, 4, 0, 3),
            (4, 4, 1, 0),
            (5, 4, 1, 1),
            (6, 4, 1, 2),
            (7, 4, 1, 3),
            (-8, 5, -1, -3),
            (-7, 5, -1, -2),
            (-6, 5, -1, -1),
            (-5, 5, -1, 0),
            (-4, 5, 0, -4),
            (-3, 5, 0, -3),
            (-2, 5, 0, -2),
            (-1, 5, 0, -1),
            (0, 5, 0, 0),
            (1, 5, 0, 1),
            (2, 5, 0, 2),
            (3, 5, 0, 3),
            (4, 5, 0, 4),
            (5, 5, 1, 0),
            (6, 5, 1, 1),
            (7, 5, 1, 2),
            (-8, 6, -1, -2),
            (-7, 6, -1, -1),
            (-6, 6, -1, 0),
            (-5, 6, 0, -5),
            (-4, 6, 0, -4),
            (-3, 6, 0, -3),
            (-2, 6, 0, -2),
            (-1, 6, 0, -1),
            (0, 6, 0, 0),
            (1, 6, 0, 1),
            (2, 6, 0, 2),
            (3, 6, 0, 3),
            (4, 6, 0, 4),
            (5, 6, 0, 5),
            (6, 6, 1, 0),
            (7, 6, 1, 1),
            (-8, 7, -1, -1),
            (-7, 7, -1, 0),
            (-6, 7, 0, -6),
            (-5, 7, 0, -5),
            (-4, 7, 0, -4),
            (-3, 7, 0, -3),
            (-2, 7, 0, -2),
            (-1, 7, 0, -1),
            (0, 7, 0, 0),
            (1, 7, 0, 1),
            (2, 7, 0, 2),
            (3, 7, 0, 3),
            (4, 7, 0, 4),
            (5, 7, 0, 5),
            (6, 7, 0, 6),
            (7, 7, 1, 0),
        ]
        for (n, d, q, r) in test_cases:
            # subTest names the failing operands and keeps checking the
            # remaining rows instead of stopping at the first mismatch.
            with self.subTest(n=n, d=d, q=q, r=r):
                self.assertEqual(div_rem(n, d, self.BIT_WIDTH, True), (q, r))

    def test_unsigned(self):
        """Check unsigned division against Python's ``//`` and ``%``.

        Division by zero is expected to produce an all-ones quotient and to
        return the numerator unchanged as the remainder.
        """
        bit_width = self.BIT_WIDTH
        for n in range(1 << bit_width):
            for d in range(1 << bit_width):
                if d == 0:
                    q = (1 << bit_width) - 1
                    r = n
                else:
                    # div_rem matches // and % for unsigned integers
                    q, r = divmod(n, d)
                with self.subTest(n=n, d=d, q=q, r=r):
                    self.assertEqual(div_rem(n, d, bit_width, False), (q, r))
287
288
class TestUnsignedDivRem(unittest.TestCase):
    """Stage-by-stage checks of the ``UnsignedDivRem`` state machine."""

    def _assert_reconstructs(self, numerator, udr):
        # quotient * divisor + remainder must equal the numerator at every
        # stage of the computation, not just once it has finished.
        self.assertEqual(numerator,
                         udr.quotient * udr.divisor + udr.remainder)

    def helper(self, log2_radix):
        """Run ``UnsignedDivRem`` on every pair of 4-bit unsigned operands.

        The final quotient and remainder are compared with ``div_rem``.

        :param log2_radix: log2 of the radix passed to ``UnsignedDivRem``.
        """
        width = 4
        stage_limit = 250 * width
        for numerator in range(1 << width):
            for divisor in range(1 << width):
                want_q, want_r = div_rem(numerator, divisor, width, False)
                with self.subTest(n=numerator, d=divisor,
                                  q=want_q, r=want_r):
                    udr = UnsignedDivRem(numerator, divisor, width,
                                         log2_radix)
                    finished = False
                    for _ in range(stage_limit):
                        self._assert_reconstructs(numerator, udr)
                        finished = udr.calculate_stage()
                        if finished:
                            break
                    if not finished:
                        self.fail("infinite loop")
                    self._assert_reconstructs(numerator, udr)
                    self.assertEqual(udr.quotient, want_q)
                    self.assertEqual(udr.remainder, want_r)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
320
321
class TestDivRem(unittest.TestCase):
    """Stage-by-stage checks of the ``DivRem`` state machine."""

    def helper(self, log2_radix):
        """Run ``DivRem`` on every pair of 4-bit operands, signed and not.

        The final quotient and remainder are compared with ``div_rem``.

        :param log2_radix: log2 of the radix passed to ``DivRem``.
        """
        bit_width = 4
        for n_bits in range(1 << bit_width):
            for d_bits in range(1 << bit_width):
                for signed in False, True:
                    # Reinterpret the raw bit patterns for this signedness.
                    # New names are used so that each iteration starts from
                    # the raw bits, not from the previous iteration's
                    # (possibly negative) reinterpretation.
                    n = Const.normalize(n_bits, (bit_width, signed))
                    d = Const.normalize(d_bits, (bit_width, signed))
                    q, r = div_rem(n, d, bit_width, signed)
                    with self.subTest(n=n, d=d, q=q, r=r, signed=signed):
                        dr = DivRem(n, d, bit_width, signed, log2_radix)
                        for _ in range(250 * bit_width):
                            if dr.calculate_stage():
                                break
                        else:
                            self.fail("infinite loop")
                        self.assertEqual(dr.quotient, q)
                        self.assertEqual(dr.remainder, r)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
352
353
class TestFixed(unittest.TestCase):
    """Tests for the ``Fixed`` fixed-point value class."""

    def _assert_fields(self, value, bits, fract_width, bit_width, signed):
        # Check the four stored fields of ``value``, bits first.
        self.assertEqual(value.bits, bits)
        self.assertEqual(value.fract_width, fract_width)
        self.assertEqual(value.bit_width, bit_width)
        self.assertEqual(value.signed, signed)

    def test_constructor(self):
        """Build Fixed values from ints, floats and other Fixed values."""
        self._assert_fields(Fixed(0, 0, 1, False), 0, 0, 1, False)
        # 1.0 with 2 fraction bits is 0b100, which wraps to -4 in 3 signed
        # bits but fits in 4.
        self._assert_fields(Fixed(1, 2, 3, True), -4, 2, 3, True)
        self._assert_fields(Fixed(1, 2, 4, True), 4, 2, 4, True)
        self._assert_fields(Fixed(1.25, 4, 8, True), 0x14, 4, 8, True)
        self._assert_fields(Fixed(Fixed(2, 0, 12, False), 4, 8, True),
                            0x20, 4, 8, True)
        wide = Fixed(0x2FF / 2 ** 8, 8, 12, False)
        self._assert_fields(wide, 0x2FF, 8, 12, False)
        # Re-encoding with 4 fraction bits drops the low 4 bits.
        self._assert_fields(Fixed(wide, 4, 8, True), 0x2F, 4, 8, True)

    def helper_test_from_bits(self, bit_width, fract_width):
        """Feed every representable bit pattern through ``from_bits``."""
        half = 1 << (bit_width - 1)
        patterns = ((False, range(1 << bit_width)),
                    (True, range(-1 << (bit_width - 1), half)))
        for signed, all_bits in patterns:
            for bits in all_bits:
                with self.subTest(bit_width=bit_width,
                                  fract_width=fract_width,
                                  signed=signed,
                                  bits=hex(bits)):
                    value = Fixed.from_bits(bits, fract_width, bit_width,
                                            signed)
                    self.assertEqual(value.bit_width, bit_width)
                    self.assertEqual(value.fract_width, fract_width)
                    self.assertEqual(value.signed, signed)
                    self.assertEqual(value.bits, bits)

    def test_from_bits(self):
        for bit_width in range(1, 5):
            for fract_width in range(bit_width):
                self.helper_test_from_bits(bit_width, fract_width)

    def test_repr(self):
        """repr() renders as the equivalent ``Fixed.from_bits`` call."""
        cases = [
            ((1, 2, 3, False), "Fixed.from_bits(1, 2, 3, False)"),
            ((-4, 2, 3, True), "Fixed.from_bits(-4, 2, 3, True)"),
            ((-4, 7, 10, True), "Fixed.from_bits(-4, 7, 10, True)"),
        ]
        for args, text in cases:
            self.assertEqual(repr(Fixed.from_bits(*args)), text)

    @staticmethod
    def _small_signed_values():
        # Every 4-bit signed pattern with 2 fraction bits, paired with its
        # raw bits i (so the represented value is i / 4).
        for i in range(-8, 8):
            yield i, Fixed.from_bits(i, 2, 4, True)

    def test_trunc(self):
        for i, value in self._small_signed_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(math.trunc(value), math.trunc(i / 4))

    def test_int(self):
        for _, value in self._small_signed_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(int(value), math.trunc(value))

    def test_float(self):
        for i, value in self._small_signed_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(float(value), i / 4)

    def test_floor(self):
        for i, value in self._small_signed_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(math.floor(value), math.floor(i / 4))

    def test_ceil(self):
        for i, value in self._small_signed_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(math.ceil(value), math.ceil(i / 4))

    def test_neg(self):
        for i, value in self._small_signed_values():
            # -(-2.0) does not fit, so it wraps back around to -2.0.
            want = -i / 4 if i != -8 else -2.0
            with self.subTest(value=repr(value)):
                self.assertEqual(float(-value), want)

    def test_pos(self):
        for i, value in self._small_signed_values():
            with self.subTest(value=repr(value)):
                self.assertEqual((+value).bits, i)

    def test_abs(self):
        for i, value in self._small_signed_values():
            # abs(-2.0) does not fit, so it wraps back around to -2.0.
            want = abs(i) / 4 if i != -8 else -2.0
            with self.subTest(value=repr(value)):
                self.assertEqual(float(abs(value)), want)

    def test_not(self):
        for i, value in self._small_signed_values():
            with self.subTest(value=repr(value)):
                self.assertEqual(float(~value), (~i) / 4)

    @staticmethod
    def get_test_values(max_bit_width, include_int):
        """Yield operands for the binary-operator tests.

        For each signedness (unsigned first), this optionally yields plain
        ints normalized from every ``max_bit_width``-bit pattern. It then
        yields Fixed values for every bit width below ``max_bit_width``,
        every fraction width from 0 up to and including that bit width,
        and every bit pattern.
        """
        for signed in False, True:
            if include_int:
                for raw in range(1 << max_bit_width):
                    yield Const.normalize(raw, (max_bit_width, signed))
            for width in range(1, max_bit_width):
                for fract in range(width + 1):
                    for raw in range(1 << width):
                        yield Fixed.from_bits(raw, fract, width, signed)

    def binary_op_test_helper(self,
                              operation,
                              is_fixed=True,
                              width_combine_op=max,
                              adjust_bits_op=None):
        """Check a binary ``operation`` against the same operation on bits.

        Operands come from ``get_test_values(5, ...)``. A plain-int operand
        is only ever paired with a Fixed one, and Fixed pairs that differ
        in signedness are skipped.

        :param operation: callable ``(lhs, rhs) -> result`` under test.
        :param is_fixed: True if ``operation`` returns a Fixed value, False
            if its result is compared directly with ``==``.
        :param width_combine_op: combines the operands' fraction widths,
            and separately their integer-part widths, into the result's.
        :param adjust_bits_op: ``(bits, out_fract_width, in_fract_width)``
            returning the bits to use in the integer reference computation.
            Defaults to shifting left by the fraction-width difference.
        """
        def shift_to_fract_width(bits, out_fract_width, in_fract_width):
            return bits << (out_fract_width - in_fract_width)

        if adjust_bits_op is None:
            adjust_bits_op = shift_to_fract_width

        def make_expected(int_result, fract_width, bit_width, signed):
            # Wrap the integer reference result in the expected format.
            if not is_fixed:
                return int_result
            return Fixed.from_bits(int_result, fract_width, bit_width,
                                   signed)

        max_bit_width = 5
        for lhs in self.get_test_values(max_bit_width, True):
            lhs_is_int = isinstance(lhs, int)
            for rhs in self.get_test_values(max_bit_width, not lhs_is_int):
                rhs_is_int = isinstance(rhs, int)
                if lhs_is_int:
                    assert not rhs_is_int
                    lhs_int = adjust_bits_op(lhs, rhs.fract_width, 0)
                    int_result = operation(lhs_int, rhs.bits)
                    expected = make_expected(int_result, rhs.fract_width,
                                             rhs.bit_width, rhs.signed)
                elif rhs_is_int:
                    rhs_int = adjust_bits_op(rhs, lhs.fract_width, 0)
                    int_result = operation(lhs.bits, rhs_int)
                    expected = make_expected(int_result, lhs.fract_width,
                                             lhs.bit_width, lhs.signed)
                elif lhs.signed != rhs.signed:
                    continue
                else:
                    fract_width = width_combine_op(lhs.fract_width,
                                                   rhs.fract_width)
                    int_width = width_combine_op(
                        lhs.bit_width - lhs.fract_width,
                        rhs.bit_width - rhs.fract_width)
                    lhs_int = adjust_bits_op(lhs.bits, fract_width,
                                             lhs.fract_width)
                    rhs_int = adjust_bits_op(rhs.bits, fract_width,
                                             rhs.fract_width)
                    int_result = operation(lhs_int, rhs_int)
                    expected = make_expected(int_result, fract_width,
                                             fract_width + int_width,
                                             lhs.signed)
                with self.subTest(lhs=repr(lhs),
                                  rhs=repr(rhs),
                                  expected=repr(expected)):
                    result = operation(lhs, rhs)
                    if not is_fixed:
                        self.assertEqual(result, expected)
                    else:
                        self.assertEqual(result.bit_width,
                                         expected.bit_width)
                        self.assertEqual(result.signed, expected.signed)
                        self.assertEqual(result.fract_width,
                                         expected.fract_width)
                        self.assertEqual(result.bits, expected.bits)

    def test_add(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs + rhs)

    def test_sub(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs - rhs)

    def test_and(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs & rhs)

    def test_or(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs | rhs)

    def test_xor(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs ^ rhs)

    def test_mul(self):
        # The integer reference uses the raw bits without rescaling, and
        # the result's fraction and integer widths are the operands' sums.
        def keep_bits(bits, out_fract_width, in_fract_width):
            return bits

        def add_widths(l_width, r_width):
            return l_width + r_width

        self.binary_op_test_helper(lambda lhs, rhs: lhs * rhs,
                                   is_fixed=True,
                                   width_combine_op=add_widths,
                                   adjust_bits_op=keep_bits)

    def test_cmp(self):
        def three_way(lhs, rhs):
            if lhs < rhs:
                return -1
            if lhs > rhs:
                return 1
            return 0
        self.binary_op_test_helper(three_way, False)

    def test_lt(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs < rhs, False)

    def test_le(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs <= rhs, False)

    def test_eq(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs == rhs, False)

    def test_ne(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs != rhs, False)

    def test_gt(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs > rhs, False)

    def test_ge(self):
        self.binary_op_test_helper(lambda lhs, rhs: lhs >= rhs, False)

    def test_bool(self):
        for value in self.get_test_values(6, False):
            with self.subTest(v=repr(value)):
                self.assertEqual(bool(value), bool(value.bits))

    def test_str(self):
        """Check the hexadecimal ``str`` rendering of Fixed values."""
        cases = [
            (Fixed.from_bits(0x1234, 0, 16, False), "fixed:0x1234."),
            (Fixed.from_bits(-0x1234, 0, 16, True), "fixed:-0x1234."),
            (Fixed.from_bits(0x12345, 3, 20, True), "fixed:0x2468.a"),
            (Fixed(123.625, 3, 12, True), "fixed:0x7b.a"),
        ]
        # Successively finer truncations of 0x1.2345 as 20-bit signed
        # values: (bits, fract_width, expected text).
        truncations = [
            (0x1, 0, "fixed:0x1."),
            (0x2, 1, "fixed:0x1.0"),
            (0x4, 2, "fixed:0x1.0"),
            (0x9, 3, "fixed:0x1.2"),
            (0x12, 4, "fixed:0x1.2"),
            (0x24, 5, "fixed:0x1.20"),
            (0x48, 6, "fixed:0x1.20"),
            (0x91, 7, "fixed:0x1.22"),
            (0x123, 8, "fixed:0x1.23"),
            (0x246, 9, "fixed:0x1.230"),
            (0x48d, 10, "fixed:0x1.234"),
            (0x91a, 11, "fixed:0x1.234"),
            (0x1234, 12, "fixed:0x1.234"),
            (0x2468, 13, "fixed:0x1.2340"),
            (0x48d1, 14, "fixed:0x1.2344"),
            (0x91a2, 15, "fixed:0x1.2344"),
            (0x12345, 16, "fixed:0x1.2345"),
            (0x2468a, 17, "fixed:0x1.23450"),
            (0x48d14, 18, "fixed:0x1.23450"),
            # 0x91a28 has bit 19 set, so as a 20-bit signed value it is
            # negative.
            (0x91a28, 19, "fixed:-0x0.dcbb0"),
        ]
        for bits, fract_width, text in truncations:
            cases.append((Fixed.from_bits(bits, fract_width, 20, True), text))
        cases.append((Fixed.from_bits(0x91a28, 19, 20, False),
                      "fixed:0x1.23450"))
        for value, text in cases:
            self.assertEqual(str(value), text)
680
681
class TestFixedSqrtFn(unittest.TestCase):
    """Tests for the reference ``fixed_sqrt`` function."""

    def test_on_ints(self):
        """Check integer radicands: -1 through 31, plus ``2 << 64``."""
        cases = []
        for radicand in range(-1, 32):
            if radicand >= 0:
                root = math.floor(math.sqrt(radicand))
                cases.append((radicand,
                              RootRemainder(root, radicand - root * root)))
            else:
                # A negative radicand is expected to give None.
                cases.append((radicand, None))
        # A large radicand whose integer root is spelled out explicitly.
        big = 2 << 64
        big_root = 0x16A09E667
        cases.append((big, RootRemainder(big_root,
                                         big - big_root * big_root)))
        for radicand, expected in cases:
            with self.subTest(radicand=radicand, expected=expected):
                self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))

    @staticmethod
    def _non_negative_radicands():
        # Every non-negative Fixed value of width 1..9 bits, for each
        # signedness and every fraction width below the bit width.
        for signed in False, True:
            for bit_width in range(1, 10):
                for fract_width in range(bit_width):
                    for bits in range(1 << bit_width):
                        candidate = Fixed.from_bits(bits, fract_width,
                                                    bit_width, signed)
                        if not (candidate < 0):
                            yield candidate

    def test_on_fixed(self):
        """Compare against ``math.sqrt`` rounded into the radicand's format."""
        for radicand in self._non_negative_radicands():
            root = radicand.with_value(math.sqrt(float(radicand)))
            expected = RootRemainder(root, radicand - root * root)
            with self.subTest(radicand=repr(radicand),
                              expected=repr(expected)):
                self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))

    def test_misc_cases(self):
        """Check the ``str`` output for a few hand-picked radicands."""
        cases = (
            # radicand, expected
            (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
            (Fixed(2, 30, 32, False),
             "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)"),
        )
        for radicand, text in cases:
            with self.subTest(radicand=str(radicand), expected=text):
                self.assertEqual(str(fixed_sqrt(radicand)), text)
729
730
class TestFixedSqrt(unittest.TestCase):
    """Stage-by-stage checks of the ``FixedSqrt`` state machine."""

    def _assert_invariants(self, obj):
        # root_squared must equal root * root and never exceed the radicand.
        self.assertEqual(obj.root * obj.root, obj.root_squared)
        self.assertGreaterEqual(obj.radicand, obj.root_squared)

    def helper(self, log2_radix):
        """Run ``FixedSqrt`` on every unsigned Fixed value up to 7 bits.

        The final root and remainder are compared with ``fixed_sqrt``.

        :param log2_radix: log2 of the radix passed to ``FixedSqrt``.
        """
        for bit_width in range(1, 8):
            for fract_width in range(bit_width):
                for raw in range(1 << bit_width):
                    radicand = Fixed.from_bits(raw, fract_width, bit_width,
                                               False)
                    reference = fixed_sqrt(radicand)
                    with self.subTest(radicand=repr(radicand),
                                      root_remainder=repr(reference),
                                      log2_radix=log2_radix):
                        obj = FixedSqrt(radicand, log2_radix)
                        finished = False
                        for _ in range(250 * bit_width):
                            self._assert_invariants(obj)
                            finished = obj.calculate_stage()
                            if finished:
                                break
                        if not finished:
                            self.fail("infinite loop")
                        self._assert_invariants(obj)
                        self.assertEqual(obj.remainder,
                                         obj.radicand - obj.root_squared)
                        self.assertEqual(obj.root, reference.root)
                        self.assertEqual(obj.remainder, reference.remainder)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)
775
776
class TestFixedRSqrtFn(unittest.TestCase):
    """Tests for the reference ``fixed_rsqrt`` function."""

    @staticmethod
    def _root_remainder(radicand, root):
        # The expected remainder is 1 - root * root * radicand.
        return RootRemainder(root, 1 - root * root * radicand)

    def _check(self, radicand, expected):
        # Compare fixed_rsqrt's repr with the expected RootRemainder's repr.
        with self.subTest(radicand=repr(radicand), expected=repr(expected)):
            self.assertEqual(repr(fixed_rsqrt(radicand)), repr(expected))

    def test2(self):
        """Check 12-bit unsigned radicands with 5 fraction bits (1..31)."""
        for bits in range(1, 1 << 5):
            radicand = Fixed.from_bits(bits, 5, 12, False)
            root = radicand.with_value(1 / math.sqrt(float(radicand)))
            self._check(radicand, self._root_remainder(radicand, root))

    def test(self):
        """Check every positive Fixed radicand of width 1..9 bits."""
        for signed in False, True:
            for bit_width in range(1, 10):
                for fract_width in range(bit_width):
                    for bits in range(1 << bit_width):
                        radicand = Fixed.from_bits(bits, fract_width,
                                                   bit_width, signed)
                        if radicand <= 0:
                            continue
                        float_root = 1 / math.sqrt(float(radicand))
                        # Largest bit pattern that is non-negative in this
                        # format; the expected root saturates to it.
                        largest = radicand.with_bits(
                            (1 << (bit_width - signed)) - 1)
                        root = (largest
                                if float_root > float(largest)
                                else radicand.with_value(float_root))
                        self._check(radicand,
                                    self._root_remainder(radicand, root))

    def test_misc_cases(self):
        """Check the ``str`` output for a hand-picked radicand."""
        cases = (
            # radicand, expected
            (Fixed(0.5, 30, 32, False),
             "RootRemainder(fixed:0x1.6a09e664, "
             "fixed:0x0.0000000596d014780000000)"),
        )
        for radicand, text in cases:
            with self.subTest(radicand=str(radicand), expected=text):
                self.assertEqual(str(fixed_rsqrt(radicand)), text)
825
826
class TestFixedRSqrt(unittest.TestCase):
    """Stage-by-stage checks of the ``FixedRSqrt`` state machine."""

    def _assert_invariants(self, obj):
        # radicand_root and radicand_root_squared must track the current
        # root, and radicand * root * root must never exceed 1.
        self.assertEqual(obj.radicand * obj.root, obj.radicand_root)
        self.assertEqual(obj.radicand_root * obj.root,
                         obj.radicand_root_squared)
        self.assertGreaterEqual(1, obj.radicand_root_squared)

    def helper(self, log2_radix):
        """Run ``FixedRSqrt`` on every positive unsigned value up to 7 bits.

        Zero is skipped because 1 / sqrt(0) is undefined. The final root
        and remainder are compared with ``fixed_rsqrt``.

        :param log2_radix: log2 of the radix passed to ``FixedRSqrt``.
        """
        for bit_width in range(1, 8):
            for fract_width in range(bit_width):
                for raw in range(1, 1 << bit_width):
                    radicand = Fixed.from_bits(raw, fract_width, bit_width,
                                               False)
                    reference = fixed_rsqrt(radicand)
                    with self.subTest(radicand=repr(radicand),
                                      root_remainder=repr(reference),
                                      log2_radix=log2_radix):
                        obj = FixedRSqrt(radicand, log2_radix)
                        finished = False
                        for _ in range(250 * bit_width):
                            self._assert_invariants(obj)
                            finished = obj.calculate_stage()
                            if finished:
                                break
                        if not finished:
                            self.fail("infinite loop")
                        self._assert_invariants(obj)
                        self.assertEqual(obj.remainder,
                                         1 - obj.radicand_root_squared)
                        self.assertEqual(obj.root, reference.root)
                        self.assertEqual(obj.remainder, reference.remainder)

    def test_radix_2(self):
        self.helper(1)

    def test_radix_4(self):
        self.helper(2)

    def test_radix_8(self):
        self.helper(3)

    def test_radix_16(self):
        self.helper(4)