# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
""" Algorithms for div/rem/sqrt/rsqrt.

Code for simulating/testing the various algorithms.
"""
import math

from nmigen.hdl.ast import Const
def div_rem(dividend, divisor, bit_width, signed):
    """ Compute the quotient/remainder following the RISC-V M extension.

    NOT the same as the // or % operators: the quotient is rounded toward
    zero and the remainder takes the sign of the dividend. Division by zero
    gives a quotient with all bits set (-1 when signed) and a remainder equal
    to the dividend. Signed overflow (most-negative value divided by -1)
    wraps, giving a quotient equal to the dividend and a remainder of 0.

    :param dividend: the dividend/numerator
    :param divisor: the divisor/denominator
    :param bit_width: the bit width of the inputs/outputs
    :param signed: if the inputs/outputs are signed instead of unsigned
    :returns tuple: ``(quotient, remainder)``, each normalized to
        ``bit_width`` bits with the requested signedness
    """
    dividend = Const.normalize(dividend, (bit_width, signed))
    divisor = Const.normalize(divisor, (bit_width, signed))
    if divisor == 0:
        # RISC-V defines x / 0 == all ones and x % 0 == x instead of trapping
        quotient = -1
        remainder = dividend
    else:
        # divide the magnitudes, then fix up the signs so the quotient
        # truncates toward zero and the remainder follows the dividend
        quotient = abs(dividend) // abs(divisor)
        remainder = abs(dividend) % abs(divisor)
        if (dividend < 0) != (divisor < 0):
            quotient = -quotient
        if dividend < 0:
            remainder = -remainder
    # wrap back into range; this is also what makes signed overflow wrap
    quotient = Const.normalize(quotient, (bit_width, signed))
    remainder = Const.normalize(remainder, (bit_width, signed))
    return quotient, remainder
class UnsignedDivRem:
    """ Unsigned integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    :attribute remainder: the remainder and/or dividend
    :attribute divisor: the divisor
    :attribute bit_width: the bit width of the inputs/outputs
    :attribute log2_radix: the base-2 log of the division radix. The number of
        bits of quotient that are calculated per pipeline stage.
    :attribute quotient: the quotient
    :attribute current_shift: the current bit index
    """

    def __init__(self, dividend, divisor, bit_width, log2_radix=3):
        """ Create an UnsignedDivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        :raises ValueError: if ``log2_radix`` is less than 1; such a radix
            would never advance ``current_shift`` and ``calculate`` would
            loop forever.
        """
        if log2_radix < 1:
            raise ValueError("log2_radix must be at least 1")
        # the remainder starts out holding the dividend and is reduced in place
        self.remainder = Const.normalize(dividend, (bit_width, False))
        self.divisor = Const.normalize(divisor, (bit_width, False))
        self.bit_width = bit_width
        self.log2_radix = log2_radix
        self.quotient = 0
        self.current_shift = bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        Each stage produces up to ``log2_radix`` quotient bits by trial
        subtraction of every multiple of the divisor that fits in the radix.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        # the final stage may have fewer bits left than a full radix digit
        log2_radix = min(self.log2_radix, self.current_shift)
        self.current_shift -= log2_radix
        radix = 1 << log2_radix
        # candidate remainder for every possible value of this quotient digit
        remainders = [
            self.remainder - ((self.divisor * i) << self.current_shift)
            for i in range(radix)
        ]
        # the digit is the largest i whose trial subtraction stays
        # non-negative; i == 0 always qualifies because the remainder never
        # goes negative. A zero divisor selects radix - 1 every stage, giving
        # the all-ones quotient RISC-V requires.
        quotient_bits = max(i for i, r in enumerate(remainders) if r >= 0)
        self.remainder = remainders[quotient_bits]
        self.quotient |= quotient_bits << self.current_shift
        return self.current_shift == 0

    def calculate(self):
        """ Calculate the results of the division.

        Runs ``calculate_stage`` until all quotient bits are produced.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
class DivRem:
    """ Integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    :attribute dividend: the dividend
    :attribute divisor: the divisor
    :attribute signed: if the inputs/outputs are signed instead of unsigned
    :attribute quotient: the quotient
    :attribute remainder: the remainder
    :attribute divider: the base UnsignedDivRem
    """

    def __init__(self, dividend, divisor, bit_width, signed, log2_radix=3):
        """ Create a DivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param signed: if the inputs/outputs are signed instead of unsigned
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        """
        self.dividend = Const.normalize(dividend, (bit_width, signed))
        self.divisor = Const.normalize(divisor, (bit_width, signed))
        self.signed = signed
        self.quotient = 0
        self.remainder = 0
        # divide the magnitudes of the *normalized* operands, so inputs given
        # as raw bit patterns (e.g. 0xFF for a signed 8-bit -1) or as
        # out-of-range values are interpreted the same way as self.dividend
        # and self.divisor
        self.divider = UnsignedDivRem(abs(self.dividend), abs(self.divisor),
                                      bit_width, log2_radix)

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        Once the underlying unsigned division finishes, the signs of the
        results are fixed up and stored in ``quotient`` and ``remainder``.

        :returns bool: True if this is the last pipeline stage.
        """
        if not self.divider.calculate_stage():
            return False
        divisor_sign = self.divisor < 0
        dividend_sign = self.dividend < 0
        # a zero divisor must keep the all-ones unsigned quotient (-1 when
        # signed) rather than negating it
        if self.divisor != 0 and divisor_sign != dividend_sign:
            quotient = -self.divider.quotient
        else:
            quotient = self.divider.quotient
        # the remainder takes the sign of the dividend
        if dividend_sign:
            remainder = -self.divider.remainder
        else:
            remainder = self.divider.remainder
        bit_width = self.divider.bit_width
        self.quotient = Const.normalize(quotient, (bit_width, self.signed))
        self.remainder = Const.normalize(remainder, (bit_width, self.signed))
        return True

    def calculate(self):
        """ Calculate the results of the division.

        Runs ``calculate_stage`` until the division is complete.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
class Fixed:
    """ Fixed-point number.

    the value is bits * 2 ** -fract_width

    :attribute bits: the bits of the fixed-point number
    :attribute fract_width: the number of bits in the fractional portion
    :attribute bit_width: the total number of bits
    :attribute signed: if the type is signed
    """

    @staticmethod
    def from_bits(bits, fract_width, bit_width, signed):
        """ Create a new Fixed from its raw bits.

        :param bits: the bits of the fixed-point number
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        :returns Fixed: the new number
        """
        retval = Fixed(0, fract_width, bit_width, signed)
        retval.bits = Const.normalize(bits, (bit_width, signed))
        return retval

    def __init__(self, value, fract_width, bit_width, signed):
        """ Create a new Fixed.

        :param value: the value of the fixed-point number. A Fixed is
            rescaled to ``fract_width`` (dropping fractional bits rounds
            down), an int is converted exactly, and any other number (e.g. a
            float) is rounded down to the nearest representable value.
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        """
        assert fract_width >= 0
        if isinstance(value, Fixed):
            if fract_width < value.fract_width:
                bits = value.bits >> (value.fract_width - fract_width)
            else:
                bits = value.bits << (fract_width - value.fract_width)
        elif isinstance(value, int):
            bits = value << fract_width
        else:
            bits = math.floor(value * 2 ** fract_width)
        self.bits = Const.normalize(bits, (bit_width, signed))
        self.fract_width = fract_width
        self.bit_width = bit_width
        self.signed = signed

    def __repr__(self):
        """ Get representation."""
        # spelled as from_bits so the text is an evaluable expression that
        # reproduces the same number, including its signedness
        return (f"Fixed.from_bits({self.bits}, {self.fract_width}, "
                f"{self.bit_width}, {self.signed})")

    def __trunc__(self):
        """ Truncate to integer (round toward zero)."""
        if self.bits < 0:
            return self.__ceil__()
        return self.__floor__()

    def __int__(self):
        """ Truncate to integer."""
        return self.__trunc__()

    def __float__(self):
        """ Convert to float."""
        # true division always yields a float (even when fract_width == 0)
        # and is correctly rounded for arbitrarily wide bits
        return self.bits / (1 << self.fract_width)

    def __floor__(self):
        """ Floor to integer."""
        return self.bits >> self.fract_width

    def __ceil__(self):
        """ Ceil to integer."""
        return -((-self.bits) >> self.fract_width)

    def __neg__(self):
        """ Negate."""
        return self.from_bits(-self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def __pos__(self):
        """ Unary Positive."""
        return self

    def __abs__(self):
        """ Absolute Value."""
        return self.from_bits(abs(self.bits), self.fract_width,
                              self.bit_width, self.signed)

    def __invert__(self):
        """ Bitwise Not."""
        return self.from_bits(~self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def _binary_op(self, rhs, operation, full=False):
        """ Handle binary arithmetic operators.

        Both operands are aligned to the larger fractional width; the result
        gets the larger integer width of the two operands.

        :param rhs: the other operand: an int, or a Fixed of the same
            signedness
        :param operation: called with the aligned bits as
            ``operation(lhs_bits, rhs_bits, fract_width)``, or, when ``full``
            is set, as ``operation(lhs_bits, rhs_bits, fract_width,
            bit_width, signed)``
        :param full: if set, return ``operation``'s result unchanged instead
            of wrapping it in a Fixed
        :returns: the resulting Fixed (or ``operation``'s raw result when
            ``full``), or NotImplemented for an unsupported ``rhs`` type
        :raises TypeError: if ``rhs`` is a Fixed of different signedness
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = max(self.bit_width - self.fract_width,
                            rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = max(self.fract_width, rhs_fract_width)
        rhs_bits <<= fract_width - rhs_fract_width
        lhs_bits = self.bits << (fract_width - self.fract_width)
        bit_width = int_width + fract_width
        if full:
            return operation(lhs_bits, rhs_bits,
                             fract_width, bit_width, self.signed)
        bits = operation(lhs_bits, rhs_bits, fract_width)
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __add__(self, rhs):
        """ Addition."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs + rhs)

    def __radd__(self, lhs):
        """ Reverse Addition."""
        return self.__add__(lhs)

    def __sub__(self, rhs):
        """ Subtraction."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs - rhs)

    def __rsub__(self, lhs):
        """ Reverse Subtraction."""
        # note swapped argument and parameter order: _binary_op passes
        # self's bits first, so they bind to the lambda's `rhs`
        return self._binary_op(lhs, lambda rhs, lhs, fract_width: lhs - rhs)

    def __and__(self, rhs):
        """ Bitwise And."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs & rhs)

    def __rand__(self, lhs):
        """ Reverse Bitwise And."""
        return self.__and__(lhs)

    def __or__(self, rhs):
        """ Bitwise Or."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs | rhs)

    def __ror__(self, lhs):
        """ Reverse Bitwise Or."""
        return self.__or__(lhs)

    def __xor__(self, rhs):
        """ Bitwise Xor."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs ^ rhs)

    def __rxor__(self, lhs):
        """ Reverse Bitwise Xor."""
        return self.__xor__(lhs)

    def __mul__(self, rhs):
        """ Multiplication.

        The result's fractional and integer widths are the sums of the
        operands' widths (an int operand has no fractional bits and adds no
        integer width).

        :raises TypeError: if ``rhs`` is a Fixed of different signedness
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = (self.bit_width - self.fract_width
                         + rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = self.fract_width + rhs_fract_width
        bit_width = int_width + fract_width
        bits = self.bits * rhs_bits
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __rmul__(self, lhs):
        """ Reverse Multiplication."""
        return self.__mul__(lhs)

    @staticmethod
    def _cmp_impl(lhs, rhs, fract_width, bit_width, signed):
        """ Three-way compare already-aligned bits (``full`` callback)."""
        if lhs < rhs:
            return -1
        if lhs == rhs:
            return 0
        return 1

    def cmp(self, rhs):
        """ Compare self with rhs.

        :returns int: returns -1 if self is less than rhs, 0 if they're equal,
            and 1 for greater than.
            Returns NotImplemented for unimplemented cases
        """
        return self._binary_op(rhs, self._cmp_impl, full=True)

    def _compare(self, rhs, predicate):
        """ Apply ``predicate`` to ``self.cmp(rhs)``.

        NotImplemented is passed through unchanged so Python can try the
        reflected operator instead of comparing NotImplemented with 0.
        """
        result = self.cmp(rhs)
        if result is NotImplemented:
            return NotImplemented
        return predicate(result)

    def __lt__(self, rhs):
        """ Less Than."""
        return self._compare(rhs, lambda c: c < 0)

    def __le__(self, rhs):
        """ Less Than or Equal."""
        return self._compare(rhs, lambda c: c <= 0)

    def __eq__(self, rhs):
        """ Equal."""
        return self._compare(rhs, lambda c: c == 0)

    def __ne__(self, rhs):
        """ Not Equal."""
        return self._compare(rhs, lambda c: c != 0)

    def __gt__(self, rhs):
        """ Greater Than."""
        return self._compare(rhs, lambda c: c > 0)

    def __ge__(self, rhs):
        """ Greater Than or Equal."""
        return self._compare(rhs, lambda c: c >= 0)

    def __bool__(self):
        """ Convert to bool: True unless the value is zero."""
        return bool(self.bits)

    def __str__(self):
        """ Get text representation (sign, then hex integer and fraction)."""
        # don't just use self.__float__() in order to work with numbers more
        # than 53 bits wide
        retval = "fixed:"
        bits = self.bits
        if bits < 0:
            retval += "-"
            bits = -bits
        int_part = bits >> self.fract_width
        fract_part = bits & ~(-1 << self.fract_width)
        # round up fract_width to nearest multiple of 4
        fract_width = (self.fract_width + 3) & ~3
        fract_part <<= (fract_width - self.fract_width)
        # integer division: str.zfill requires an int width
        fract_width_in_hex_digits = fract_width // 4
        retval += f"{int_part:x}."
        retval += f"{fract_part:x}".zfill(fract_width_in_hex_digits)
        return retval
398 raise NotImplementedError()
408 raise NotImplementedError()