# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information

""" Algorithms for div/rem/sqrt/rsqrt.

code for simulating/testing the various algorithms
"""

import enum
import math

from nmigen.hdl.ast import Const
def div_rem(dividend, divisor, bit_width, signed):
    """ Compute the quotient/remainder following the RISC-V M extension.

    NOT the same as the // or % operators: the quotient is rounded toward
    zero and the remainder takes the sign of the dividend.

    Special cases, as defined by the RISC-V M extension:

    * division by zero gives a quotient with every bit set and a remainder
      equal to the dividend;
    * signed overflow (most-negative value divided by ``-1``) gives a
      quotient equal to the dividend and a remainder of zero -- this falls
      out of the final normalization of the quotient.

    :param dividend: the dividend/numerator, normalized to ``bit_width`` bits.
    :param divisor: the divisor/denominator, normalized to ``bit_width`` bits.
    :param bit_width: the bit width of the inputs/outputs.
    :param signed: if the inputs/outputs are signed instead of unsigned.
    :returns: ``(quotient, remainder)``, each normalized to ``bit_width`` bits
        with the requested signedness.
    """
    dividend = Const.normalize(dividend, (bit_width, signed))
    divisor = Const.normalize(divisor, (bit_width, signed))
    if divisor == 0:
        # RISC-V defines a result instead of trapping; -1 normalizes to
        # all ones for both signed and unsigned results.
        quotient = -1
        remainder = dividend
    else:
        # divide the magnitudes, then reapply the signs (truncating division)
        quotient = abs(dividend) // abs(divisor)
        remainder = abs(dividend) % abs(divisor)
        if (dividend < 0) != (divisor < 0):
            quotient = -quotient
        if dividend < 0:
            remainder = -remainder
    quotient = Const.normalize(quotient, (bit_width, signed))
    remainder = Const.normalize(remainder, (bit_width, signed))
    return quotient, remainder
class UnsignedDivRem:
    """ Unsigned integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    The quotient is found ``log2_radix`` bits per pipeline stage, starting
    from the most-significant bits. Each stage computes the value of
    ``quotient_times_divisor`` for every candidate value of the next quotient
    bits, and keeps the largest candidate that does not exceed the dividend.
    With a zero divisor every candidate is equal, so every quotient bit ends
    up set and the remainder equals the dividend.

    :attribute dividend: the dividend
    :attribute remainder: the remainder; only set by the last pipeline stage.
    :attribute divisor: the divisor
    :attribute bit_width: the bit width of the inputs/outputs
    :attribute log2_radix: the base-2 log of the division radix. The number of
        bits of quotient that are calculated per pipeline stage.
    :attribute quotient: the quotient
    :attribute quotient_times_divisor: ``quotient * divisor``
    :attribute current_shift: the current bit index
    """

    def __init__(self, dividend, divisor, bit_width, log2_radix=3):
        """ Create an UnsignedDivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        """
        self.dividend = Const.normalize(dividend, (bit_width, False))
        self.divisor = Const.normalize(divisor, (bit_width, False))
        self.bit_width = bit_width
        self.log2_radix = log2_radix
        self.quotient = 0
        self.quotient_times_divisor = self.quotient * self.divisor
        self.current_shift = bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        # the last stage may compute fewer bits than log2_radix
        log2_radix = min(self.log2_radix, self.current_shift)
        assert log2_radix > 0
        self.current_shift -= log2_radix
        radix = 1 << log2_radix
        # trial_values[i] is quotient_times_divisor if the next quotient bits
        # were ``i``
        trial_values = []
        for i in range(radix):
            v = self.quotient_times_divisor
            v += (self.divisor * i) << self.current_shift
            trial_values.append(v)
        # trial values are non-decreasing in i, so the last one that does not
        # exceed the dividend selects the next quotient bits
        quotient_bits = 0
        next_product = self.quotient_times_divisor
        for i in range(radix):
            if self.dividend >= trial_values[i]:
                quotient_bits = i
                next_product = trial_values[i]
        self.quotient_times_divisor = next_product
        self.quotient |= quotient_bits << self.current_shift
        if self.current_shift == 0:
            self.remainder = self.dividend - self.quotient_times_divisor
            return True
        return False

    def calculate(self):
        """ Calculate the results of the division.

        Runs ``calculate_stage`` until the last pipeline stage is done.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
class DivRem:
    """ integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    :attribute dividend: the dividend
    :attribute divisor: the divisor
    :attribute signed: if the inputs/outputs are signed instead of unsigned
    :attribute quotient: the quotient
    :attribute remainder: the remainder
    :attribute divider: the base UnsignedDivRem
    """

    def __init__(self, dividend, divisor, bit_width, signed, log2_radix=3):
        """ Create a DivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param signed: if the inputs/outputs are signed instead of unsigned
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        """
        self.dividend = Const.normalize(dividend, (bit_width, signed))
        self.divisor = Const.normalize(divisor, (bit_width, signed))
        self.signed = signed
        self.quotient = 0
        self.remainder = 0
        # Take the magnitudes of the *normalized* operands: ``abs`` of an
        # out-of-range argument (e.g. 0xFF meant as an 8-bit signed -1) would
        # otherwise feed the wrong magnitude to the unsigned divider.
        self.divider = UnsignedDivRem(abs(self.dividend), abs(self.divisor),
                                      bit_width, log2_radix)

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        :returns bool: True if this is the last pipeline stage.
        """
        if not self.divider.calculate_stage():
            return False
        divisor_sign = self.divisor < 0
        dividend_sign = self.dividend < 0
        # with a zero divisor the unsigned divider's quotient is kept as-is
        if self.divisor != 0 and divisor_sign != dividend_sign:
            quotient = -self.divider.quotient
        else:
            quotient = self.divider.quotient
        # the remainder takes the sign of the dividend
        if dividend_sign:
            remainder = -self.divider.remainder
        else:
            remainder = self.divider.remainder
        bit_width = self.divider.bit_width
        self.quotient = Const.normalize(quotient, (bit_width, self.signed))
        self.remainder = Const.normalize(remainder, (bit_width, self.signed))
        return True

    def calculate(self):
        """ Calculate the results of the division.

        Runs ``calculate_stage`` until the last pipeline stage is done.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
class Fixed:
    """ Fixed-point number.

    the value is bits * 2 ** -fract_width

    All results are wrapped (via ``Const.normalize``) to their ``bit_width``.

    :attribute bits: the bits of the fixed-point number
    :attribute fract_width: the number of bits in the fractional portion
    :attribute bit_width: the total number of bits
    :attribute signed: if the type is signed
    """

    @staticmethod
    def from_bits(bits, fract_width, bit_width, signed):
        """ Create a new Fixed.

        :param bits: the bits of the fixed-point number; normalized to
            ``bit_width`` bits.
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        :returns Fixed: the new Fixed.
        """
        retval = Fixed(0, fract_width, bit_width, signed)
        retval.bits = Const.normalize(bits, (bit_width, signed))
        return retval

    def __init__(self, value, fract_width, bit_width, signed):
        """ Create a new Fixed.

        Note: ``value`` is not the same as ``bits``. To put a particular number
        in ``bits``, use ``Fixed.from_bits``.

        :param value: the value of the fixed-point number. A ``Fixed`` has its
            bits shifted to the new ``fract_width`` (extra fractional bits are
            dropped with an arithmetic right shift); an ``int`` is shifted
            left by ``fract_width``; anything else is scaled and floored.
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        """
        assert fract_width >= 0
        assert bit_width > 0
        if isinstance(value, Fixed):
            if fract_width < value.fract_width:
                bits = value.bits >> (value.fract_width - fract_width)
            else:
                bits = value.bits << (fract_width - value.fract_width)
        elif isinstance(value, int):
            bits = value << fract_width
        else:
            bits = math.floor(value * 2 ** fract_width)
        self.bits = Const.normalize(bits, (bit_width, signed))
        self.fract_width = fract_width
        self.bit_width = bit_width
        self.signed = signed

    def with_bits(self, bits):
        """ Create a new Fixed with the specified bits.

        :param bits: the new bits.
        :returns Fixed: the new Fixed, with the same ``fract_width``,
            ``bit_width`` and ``signed`` as ``self``.
        """
        return self.from_bits(bits,
                              self.fract_width,
                              self.bit_width,
                              self.signed)

    def with_value(self, value):
        """ Create a new Fixed with the specified value.

        :param value: the new value.
        :returns Fixed: the new Fixed, with the same ``fract_width``,
            ``bit_width`` and ``signed`` as ``self``.
        """
        return Fixed(value,
                     self.fract_width,
                     self.bit_width,
                     self.signed)

    def __repr__(self):
        """ Get representation."""
        retval = f"Fixed.from_bits({self.bits}, {self.fract_width}, "
        return retval + f"{self.bit_width}, {self.signed})"

    def __trunc__(self):
        """ Truncate to integer (round toward zero)."""
        if self.bits < 0:
            return self.__ceil__()
        return self.__floor__()

    def __int__(self):
        """ Truncate to integer."""
        return self.__trunc__()

    def __float__(self):
        """ Convert to float."""
        return self.bits * 2.0 ** -self.fract_width

    def __floor__(self):
        """ Floor to integer."""
        return self.bits >> self.fract_width

    def __ceil__(self):
        """ Ceil to integer."""
        return -((-self.bits) >> self.fract_width)

    def __neg__(self):
        """ Negate (wrapped to ``bit_width``)."""
        return self.from_bits(-self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def __pos__(self):
        """ Unary Positive."""
        return self

    def __abs__(self):
        """ Absolute Value (wrapped to ``bit_width``)."""
        return self.from_bits(abs(self.bits), self.fract_width,
                              self.bit_width, self.signed)

    def __invert__(self):
        """ Bitwise Inverse."""
        return self.from_bits(~self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def _binary_op(self, rhs, operation, full=False):
        """ Handle binary arithmetic operators.

        Both operands are aligned to the larger ``fract_width``. The integer
        width of the result is the larger of the operands' integer widths
        (``self``'s when ``rhs`` is an ``int``).

        :param rhs: the right-hand operand, an ``int`` or a ``Fixed``.
        :param operation: called as ``operation(lhs_bits, rhs_bits,
            fract_width)``, or with ``bit_width, signed`` appended when
            ``full`` is set.
        :param full: if set, return ``operation``'s result directly instead
            of wrapping it in a ``Fixed``.
        :returns: the result, or ``NotImplemented`` for unsupported ``rhs``.
        :raises TypeError: if ``rhs`` is a ``Fixed`` of different signedness.
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = max(self.bit_width - self.fract_width,
                            rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = max(self.fract_width, rhs_fract_width)
        rhs_bits <<= fract_width - rhs_fract_width
        lhs_bits = self.bits << fract_width - self.fract_width
        bit_width = int_width + fract_width
        if full:
            return operation(lhs_bits, rhs_bits,
                             fract_width, bit_width, self.signed)
        bits = operation(lhs_bits, rhs_bits,
                         fract_width)
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __add__(self, rhs):
        """ Addition."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs + rhs)

    def __radd__(self, lhs):
        """ Reverse Addition."""
        return self.__add__(lhs)

    def __sub__(self, rhs):
        """ Subtraction."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs - rhs)

    def __rsub__(self, lhs):
        """ Reverse Subtraction."""
        # note swapped argument and parameter order
        return self._binary_op(lhs, lambda rhs, lhs, fract_width: lhs - rhs)

    def __and__(self, rhs):
        """ Bitwise And."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs & rhs)

    def __rand__(self, lhs):
        """ Reverse Bitwise And."""
        return self.__and__(lhs)

    def __or__(self, rhs):
        """ Bitwise Or."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs | rhs)

    def __ror__(self, lhs):
        """ Reverse Bitwise Or."""
        return self.__or__(lhs)

    def __xor__(self, rhs):
        """ Bitwise Xor."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs ^ rhs)

    def __rxor__(self, lhs):
        """ Reverse Bitwise Xor."""
        return self.__xor__(lhs)

    def __mul__(self, rhs):
        """ Multiplication.

        The result is exact: its fract-width and integer width are the sums
        of the operands' (an ``int`` operand contributes no width).

        :raises TypeError: if ``rhs`` is a ``Fixed`` of different signedness.
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = (self.bit_width - self.fract_width
                         + rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = self.fract_width + rhs_fract_width
        bit_width = int_width + fract_width
        bits = self.bits * rhs_bits
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __rmul__(self, rhs):
        """ Reverse Multiplication. """
        return self.__mul__(rhs)

    @staticmethod
    def _cmp_impl(lhs, rhs, fract_width, bit_width, signed):
        """ Three-way compare of aligned bits; the extra arguments are unused.
        """
        if lhs < rhs:
            return -1
        elif lhs == rhs:
            return 0
        return 1

    def cmp(self, rhs):
        """ Compare self with rhs.

        :returns int: returns -1 if self is less than rhs, 0 if they're equal,
            and 1 for greater than.
            Returns NotImplemented for unimplemented cases
        :raises TypeError: if ``rhs`` is a ``Fixed`` of different signedness.
        """
        return self._binary_op(rhs, self._cmp_impl, full=True)

    def _compare(self, rhs, predicate):
        """ Apply ``predicate`` to ``self.cmp(rhs)``, passing
        ``NotImplemented`` through so Python can try the reflected operator.
        """
        result = self.cmp(rhs)
        if result is NotImplemented:
            return NotImplemented
        return predicate(result)

    def __lt__(self, rhs):
        """ Less Than."""
        return self._compare(rhs, lambda c: c < 0)

    def __le__(self, rhs):
        """ Less Than or Equal."""
        return self._compare(rhs, lambda c: c <= 0)

    def __eq__(self, rhs):
        """ Equal."""
        return self._compare(rhs, lambda c: c == 0)

    def __ne__(self, rhs):
        """ Not Equal."""
        return self._compare(rhs, lambda c: c != 0)

    def __gt__(self, rhs):
        """ Greater Than."""
        return self._compare(rhs, lambda c: c > 0)

    def __ge__(self, rhs):
        """ Greater Than or Equal."""
        return self._compare(rhs, lambda c: c >= 0)

    def __bool__(self):
        """ Convert to bool."""
        return bool(self.bits)

    def __str__(self):
        """ Get text representation."""
        # don't just use self.__float__() in order to work with numbers more
        # than 53 bits wide
        retval = "fixed:"
        bits = self.bits
        if bits < 0:
            retval += "-"
            bits = -bits
        int_part = bits >> self.fract_width
        fract_part = bits & ~(-1 << self.fract_width)
        # round up fract_width to nearest multiple of 4
        fract_width = (self.fract_width + 3) & ~3
        fract_part <<= (fract_width - self.fract_width)
        fract_width_in_hex_digits = fract_width // 4
        retval += f"0x{int_part:x}."
        if fract_width_in_hex_digits != 0:
            retval += f"{fract_part:x}".zfill(fract_width_in_hex_digits)
        return retval
class RootRemainder:
    """ A polynomial root and remainder.

    :attribute root: the polynomial root.
    :attribute remainder: the remainder.
    """

    def __init__(self, root, remainder):
        """ Create a new RootRemainder.

        :param root: the polynomial root.
        :param remainder: the remainder.
        """
        self.root = root
        self.remainder = remainder

    def __repr__(self):
        """ Get the representation as a string. """
        return f"RootRemainder({repr(self.root)}, {repr(self.remainder)})"

    def __str__(self):
        """ Convert to a string. """
        return f"RootRemainder({str(self.root)}, {str(self.remainder)})"
def fixed_sqrt(radicand):
    """ Compute the Square Root and Remainder.

    Solves the polynomial ``radicand - x * x == 0``

    The root is built one bit at a time from the most-significant bit,
    keeping each bit whose inclusion leaves ``radicand - root * root``
    non-negative.

    :param radicand: the ``Fixed`` to take the square root of. A plain
        ``int`` is also accepted.
    :returns RootRemainder: the root and ``radicand - root * root``; both are
        ``int`` when ``radicand`` is an ``int``. Returns ``None`` if
        ``radicand`` is negative.
    :raises TypeError: if ``radicand`` is neither an ``int`` nor a ``Fixed``.
    """
    # Written for correctness, not speed
    is_int = isinstance(radicand, int)
    # validate the type before comparing, so bad inputs get a clear error
    if not is_int and not isinstance(radicand, Fixed):
        raise TypeError("radicand must be an int or a Fixed")
    if radicand < 0:
        return None
    if is_int:
        # one extra bit since the Fixed is signed
        radicand = Fixed(radicand, 0, radicand.bit_length() + 1, True)

    def is_remainder_non_negative(root):
        return radicand >= root * root

    root = radicand.with_bits(0)
    for i in reversed(range(root.bit_width)):
        new_root = root.with_bits(root.bits | (1 << i))
        if new_root < 0:  # skip sign bit
            continue
        if is_remainder_non_negative(new_root):
            root = new_root
    remainder = radicand - root * root
    if is_int:
        root = int(root)
        remainder = int(remainder)
    return RootRemainder(root, remainder)
class FixedSqrt:
    """ Fixed-point Square-Root/Remainder.

    :attribute radicand: the radicand
    :attribute root: the square root
    :attribute root_squared: the square of ``root``
    :attribute remainder: the remainder
    :attribute log2_radix: the base-2 log of the operation radix. The number of
        bits of root that are calculated per pipeline stage.
    :attribute current_shift: the current bit index
    """

    def __init__(self, radicand, log2_radix=3):
        """ Create a FixedSqrt.

        :param radicand: the radicand; must be an unsigned ``Fixed``.
        :param log2_radix: the base-2 log of the operation radix. The number of
            bits of root that are calculated per pipeline stage.
        """
        assert isinstance(radicand, Fixed)
        assert radicand.signed is False
        self.radicand = radicand
        self.root = radicand.with_bits(0)
        self.root_squared = self.root * self.root
        self.remainder = radicand.with_bits(0) - self.root_squared
        self.log2_radix = log2_radix
        self.current_shift = self.root.bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the operation.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        log2_radix = min(self.log2_radix, self.current_shift)
        assert log2_radix > 0
        self.current_shift -= log2_radix
        radix = 1 << log2_radix
        # trial_squares[i] is (root + delta) ** 2, where delta is ``i``
        # placed at bit ``current_shift``, computed incrementally as
        # root_squared + root * (2 * delta) + delta * delta
        trial_squares = []
        for i in range(radix):
            v = self.root_squared
            factor1 = Fixed.from_bits(i << (self.current_shift + 1),
                                      self.root.fract_width,
                                      self.root.bit_width + 1 + log2_radix,
                                      False)
            v += self.root * factor1
            factor2 = Fixed.from_bits(i << self.current_shift,
                                      self.root.fract_width,
                                      self.root.bit_width + log2_radix,
                                      False)
            v += factor2 * factor2
            trial_squares.append(self.root_squared.with_value(v))
        # keep the largest candidate whose square does not exceed radicand
        root_bits = 0
        new_root_squared = self.root_squared
        for i in range(radix):
            if self.radicand >= trial_squares[i]:
                root_bits = i
                new_root_squared = trial_squares[i]
        self.root |= Fixed.from_bits(root_bits << self.current_shift,
                                     self.root.fract_width,
                                     self.root.bit_width + log2_radix,
                                     False)
        self.root_squared = new_root_squared
        if self.current_shift == 0:
            self.remainder = self.radicand - self.root_squared
            return True
        return False

    def calculate(self):
        """ Calculate the results of the square root.

        Runs ``calculate_stage`` until the last pipeline stage is done.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
def fixed_rsqrt(radicand):
    """ Compute the Reciprocal Square Root and Remainder.

    Solves the polynomial ``1 - x * x * radicand == 0``

    The root is built one bit at a time from the most-significant bit,
    keeping each bit whose inclusion leaves ``1 - root * root * radicand``
    non-negative.

    :param radicand: the ``Fixed`` to take the reciprocal square root of.
    :returns RootRemainder: the root and ``1 - root * root * radicand``.
        Returns ``None`` if ``radicand`` is not positive.
    :raises TypeError: if ``radicand`` is not a ``Fixed``.
    """
    # Written for correctness, not speed
    # validate the type before comparing, so bad inputs get a clear error
    if not isinstance(radicand, Fixed):
        raise TypeError("radicand must be a Fixed")
    if radicand <= 0:
        return None

    def is_remainder_non_negative(root):
        return 1 >= root * root * radicand

    root = radicand.with_bits(0)
    for i in reversed(range(root.bit_width)):
        new_root = root.with_bits(root.bits | (1 << i))
        if new_root < 0:  # skip sign bit
            continue
        if is_remainder_non_negative(new_root):
            root = new_root
    remainder = 1 - root * root * radicand
    return RootRemainder(root, remainder)
class FixedRSqrt:
    """ Fixed-point Reciprocal-Square-Root/Remainder.

    :attribute radicand: the radicand
    :attribute root: the reciprocal square root
    :attribute radicand_root: ``radicand * root``
    :attribute radicand_root_squared: ``radicand * root * root``
    :attribute remainder: the remainder
    :attribute log2_radix: the base-2 log of the operation radix. The number of
        bits of root that are calculated per pipeline stage.
    :attribute current_shift: the current bit index
    """

    def __init__(self, radicand, log2_radix=3):
        """ Create a FixedRSqrt.

        :param radicand: the radicand; must be an unsigned ``Fixed``.
        :param log2_radix: the base-2 log of the operation radix. The number of
            bits of root that are calculated per pipeline stage.
        """
        assert isinstance(radicand, Fixed)
        assert radicand.signed is False
        self.radicand = radicand
        self.root = radicand.with_bits(0)
        self.radicand_root = radicand.with_bits(0) * self.root
        self.radicand_root_squared = self.radicand_root * self.root
        self.remainder = radicand.with_bits(0) - self.radicand_root_squared
        self.log2_radix = log2_radix
        self.current_shift = self.root.bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the operation.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        log2_radix = min(self.log2_radix, self.current_shift)
        assert log2_radix > 0
        self.current_shift -= log2_radix
        radix = 1 << log2_radix
        # trial_values[i] is radicand * (root + delta) ** 2, where delta is
        # ``i`` placed at bit ``current_shift``, computed incrementally as
        # radicand_root_squared + radicand_root * (2 * delta)
        #     + radicand * delta * delta
        trial_values = []
        for i in range(radix):
            v = self.radicand_root_squared
            factor1 = Fixed.from_bits(i << (self.current_shift + 1),
                                      self.root.fract_width,
                                      self.root.bit_width + 1 + log2_radix,
                                      False)
            v += self.radicand_root * factor1
            factor2 = Fixed.from_bits(i << self.current_shift,
                                      self.root.fract_width,
                                      self.root.bit_width + log2_radix,
                                      False)
            v += self.radicand * factor2 * factor2
            trial_values.append(self.radicand_root_squared.with_value(v))
        # keep the largest candidate that does not exceed 1
        root_bits = 0
        new_radicand_root_squared = self.radicand_root_squared
        for i in range(radix):
            if 1 >= trial_values[i]:
                root_bits = i
                new_radicand_root_squared = trial_values[i]
        v = self.radicand_root
        v += self.radicand * Fixed.from_bits(root_bits << self.current_shift,
                                             self.root.fract_width,
                                             self.root.bit_width + log2_radix,
                                             False)
        self.radicand_root = self.radicand_root.with_value(v)
        self.root |= Fixed.from_bits(root_bits << self.current_shift,
                                     self.root.fract_width,
                                     self.root.bit_width + log2_radix,
                                     False)
        self.radicand_root_squared = new_radicand_root_squared
        if self.current_shift == 0:
            self.remainder = 1 - self.radicand_root_squared
            return True
        return False

    def calculate(self):
        """ Calculate the results of the reciprocal square root.

        Runs ``calculate_stage`` until the last pipeline stage is done.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
class Operation(enum.Enum):
    """ Operation for ``FixedUDivRemSqrtRSqrt``.

    Each member's value is a human-readable name of the operation.
    """

    # quotient and remainder of an unsigned division
    UDivRem = "unsigned-divide/remainder"
    # square root and its remainder
    SqrtRem = "square-root/remainder"
    # reciprocal square root and its remainder
    RSqrtRem = "reciprocal-square-root/remainder"
class FixedUDivRemSqrtRSqrt:
    """ Combined class for computing fixed-point unsigned div/rem/sqrt/rsqrt.

    Algorithm based on ``UnsignedDivRem``, ``FixedSqrt``, and ``FixedRSqrt``.

    Formulas solved are:

    * div/rem: ``dividend == quotient_root * divisor_radicand``
    * sqrt/rem: ``divisor_radicand == quotient_root * quotient_root``
    * rsqrt/rem: ``1 == quotient_root * quotient_root * divisor_radicand``

    The remainder is the left-hand-side of the comparison minus the
    right-hand-side of the comparison in the above formulas.

    Important: not all variables have the same bit-width or fract-width. For
    instance, ``dividend`` has a bit-width of ``bit_width + fract_width``
    and a fract-width of ``2 * fract_width`` bits.

    :attribute dividend: dividend for div/rem. Variable with a bit-width of
        ``bit_width + fract_width`` and a fract-width of ``fract_width * 2``
        bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Variable with a bit-width of ``bit_width`` and a
        fract-width of ``fract_width`` bits.
    :attribute operation: the ``Operation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Variable with a bit-width of ``bit_width`` and a fract-width
        of ``fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Variable with a bit-width of ``bit_width * 3`` and a fract-width
        of ``fract_width * 3`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Variable with a bit-width of ``bit_width * 2`` and a fract-width of
        ``fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Variable with a bit-width of ``bit_width * 3``
        and a fract-width of ``fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Variable with a bit-width of ``bit_width * 3``
        and a fract-width of ``fract_width * 3`` bits.
    :attribute bit_width: base bit-width. Constant int.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point. Constant int.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage (invocation of ``calculate_stage``).
        Constant int.
    :attribute current_shift: the current bit index. Variable int.
    """

    def __init__(self,
                 dividend,
                 divisor_radicand,
                 operation,
                 bit_width,
                 fract_width,
                 log2_radix):
        """ Create a new ``FixedUDivRemSqrtRSqrt``.

        :param dividend: ``dividend`` attribute's initializer.
        :param divisor_radicand: ``divisor_radicand`` attribute's initializer.
        :param operation: ``operation`` attribute's initializer.
        :param bit_width: ``bit_width`` attribute's initializer.
        :param fract_width: ``fract_width`` attribute's initializer.
        :param log2_radix: ``log2_radix`` attribute's initializer.
        """
        assert bit_width > 0
        assert fract_width >= 0
        assert fract_width <= bit_width
        assert log2_radix > 0
        self.dividend = Const.normalize(dividend,
                                        (bit_width + fract_width, False))
        self.divisor_radicand = Const.normalize(divisor_radicand,
                                                (bit_width, False))
        self.quotient_root = 0
        self.root_times_radicand = 0
        # every comparison is done with a fract-width of ``fract_width * 3``
        if operation is Operation.UDivRem:
            self.compare_lhs = self.dividend << fract_width
        elif operation is Operation.SqrtRem:
            self.compare_lhs = self.divisor_radicand << (fract_width * 2)
        else:
            assert operation is Operation.RSqrtRem
            self.compare_lhs = 1 << (fract_width * 3)
        self.compare_rhs = 0
        self.remainder = self.compare_lhs
        self.operation = operation
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        self.current_shift = bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the operation.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        log2_radix = min(self.log2_radix, self.current_shift)
        assert log2_radix > 0
        self.current_shift -= log2_radix
        radix = 1 << log2_radix
        # trial_compare_rhs_values[trial_bits] is compare_rhs if the next
        # bits of quotient_root were ``trial_bits``
        trial_compare_rhs_values = []
        for trial_bits in range(radix):
            shifted_trial_bits = trial_bits << self.current_shift
            shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits
            v = self.compare_rhs
            if self.operation is Operation.UDivRem:
                factor1 = self.divisor_radicand * shifted_trial_bits
                v += factor1 << self.fract_width
            elif self.operation is Operation.SqrtRem:
                factor1 = self.quotient_root * (shifted_trial_bits << 1)
                v += factor1 << self.fract_width
                factor2 = shifted_trial_bits_sqrd
                v += factor2 << self.fract_width
            else:
                assert self.operation is Operation.RSqrtRem
                factor1 = self.root_times_radicand * (shifted_trial_bits << 1)
                v += factor1
                factor2 = self.divisor_radicand * shifted_trial_bits_sqrd
                v += factor2
            trial_compare_rhs_values.append(v)
        # keep the largest candidate that does not exceed compare_lhs
        shifted_next_bits = 0
        next_compare_rhs = trial_compare_rhs_values[0]
        for trial_bits in range(radix):
            if self.compare_lhs >= trial_compare_rhs_values[trial_bits]:
                shifted_next_bits = trial_bits << self.current_shift
                next_compare_rhs = trial_compare_rhs_values[trial_bits]
        self.root_times_radicand += self.divisor_radicand * shifted_next_bits
        self.compare_rhs = next_compare_rhs
        self.quotient_root |= shifted_next_bits
        self.remainder = self.compare_lhs - self.compare_rhs
        return self.current_shift == 0

    def calculate(self):
        """ Calculate the results of the operation.

        Runs ``calculate_stage`` until the last pipeline stage is done.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self