change all uses of dataclass to plain_data
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / algorithm.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """ Algorithms for div/rem/sqrt/rsqrt.
5
6 code for simulating/testing the various algorithms
7 """
8
9 from nmigen.hdl.ast import Const
10
11
def div_rem(dividend, divisor, bit_width, signed):
    """ Compute the quotient/remainder following the RISC-V M extension.

    NOT the same as the // or % operators: the quotient is truncated towards
    zero, the remainder takes the sign of the dividend, and dividing by zero
    gives a quotient of -1 and a remainder equal to the dividend.

    :param dividend: the dividend/numerator
    :param divisor: the divisor/denominator
    :param bit_width: the bit width of the inputs/outputs
    :param signed: if the inputs/outputs are signed instead of unsigned
    :returns: tuple ``(quotient, remainder)``, each passed through
        ``Const.normalize`` with shape ``(bit_width, signed)``
    """
    shape = (bit_width, signed)
    num = Const.normalize(dividend, shape)
    den = Const.normalize(divisor, shape)
    if den == 0:
        # RISC-V: x / 0 gives -1 and x % 0 gives x
        quot, rem = -1, num
    else:
        # divide the magnitudes, then restore the signs so the quotient is
        # truncated towards zero and the remainder follows the dividend
        quot, rem = divmod(abs(num), abs(den))
        if (num < 0) != (den < 0):
            quot = -quot
        if num < 0:
            rem = -rem
    return Const.normalize(quot, shape), Const.normalize(rem, shape)
32
33
class UnsignedDivRem:
    """ Unsigned integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    Each call to ``calculate_stage`` produces up to ``log2_radix`` quotient
    bits, starting at the most-significant end, by choosing the largest
    digit that keeps the running remainder non-negative.

    :attribute remainder: the remainder and/or dividend
    :attribute divisor: the divisor
    :attribute bit_width: the bit width of the inputs/outputs
    :attribute log2_radix: the base-2 log of the division radix. The number of
        bits of quotient that are calculated per pipeline stage.
    :attribute quotient: the quotient
    :attribute current_shift: the current bit index
    """

    def __init__(self, dividend, divisor, bit_width, log2_radix=3):
        """ Create an UnsignedDivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        """
        unsigned_shape = (bit_width, False)
        self.remainder = Const.normalize(dividend, unsigned_shape)
        self.divisor = Const.normalize(divisor, unsigned_shape)
        self.bit_width = bit_width
        self.log2_radix = log2_radix
        self.quotient = 0
        # no quotient bits have been produced yet
        self.current_shift = bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        # the final stage may produce fewer bits than log2_radix
        step = min(self.log2_radix, self.current_shift)
        assert step > 0
        self.current_shift -= step
        shift = self.current_shift
        # trial remainder for every possible value of this stage's digit
        trials = [self.remainder - ((self.divisor * digit) << shift)
                  for digit in range(1 << step)]
        # largest digit whose trial remainder is still non-negative
        digit = max((d for d, trial in enumerate(trials) if trial >= 0),
                    default=0)
        self.remainder = trials[digit]
        self.quotient |= digit << shift
        return shift == 0

    def calculate(self):
        """ Calculate the results of the division.

        :returns: self
        """
        while True:
            if self.calculate_stage():
                return self
95
96
class DivRem:
    """ integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    The magnitudes of the operands are divided by an ``UnsignedDivRem``;
    the signs are applied once the unsigned division is complete.

    :attribute dividend: the dividend
    :attribute divisor: the divisor
    :attribute signed: if the inputs/outputs are signed instead of unsigned
    :attribute quotient: the quotient
    :attribute remainder: the remainder
    :attribute divider: the base UnsignedDivRem
    """

    def __init__(self, dividend, divisor, bit_width, signed, log2_radix=3):
        """ Create a DivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param signed: if the inputs/outputs are signed instead of unsigned
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        """
        self.dividend = Const.normalize(dividend, (bit_width, signed))
        self.divisor = Const.normalize(divisor, (bit_width, signed))
        self.signed = signed
        self.quotient = 0
        self.remainder = 0
        # Use the magnitudes of the *normalized* operands. The raw arguments
        # may lie outside the range of the requested shape (e.g. 0xFF passed
        # as a signed 8-bit value, which means -1), and taking abs() of those
        # would divide the wrong numbers.
        self.divider = UnsignedDivRem(abs(self.dividend), abs(self.divisor),
                                      bit_width, log2_radix)

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        Once the unsigned divider has finished, the quotient is negated when
        the operand signs differ (and the divisor is non-zero), the remainder
        takes the sign of the dividend, and both are normalized to the
        output shape.

        :returns bool: True if this is the last pipeline stage.
        """
        if not self.divider.calculate_stage():
            return False
        divisor_sign = self.divisor < 0
        dividend_sign = self.dividend < 0
        # with a zero divisor the unsigned divider leaves an all-ones
        # quotient, which must not be negated (it normalizes to -1 / max)
        if self.divisor != 0 and divisor_sign != dividend_sign:
            quotient = -self.divider.quotient
        else:
            quotient = self.divider.quotient
        if dividend_sign:
            remainder = -self.divider.remainder
        else:
            remainder = self.divider.remainder
        bit_width = self.divider.bit_width
        self.quotient = Const.normalize(quotient, (bit_width, self.signed))
        self.remainder = Const.normalize(remainder, (bit_width, self.signed))
        return True

    def calculate(self):
        """ Calculate the results of the division.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
149
150
class Fixed:
    """ Fixed-point number.

    the value is bits * 2 ** -fract_width

    :attribute bits: the bits of the fixed-point number
    :attribute fract_width: the number of bits in the fractional portion
    :attribute bit_width: the total number of bits
    :attribute signed: if the type is signed
    """

    @staticmethod
    def from_bits(bits, fract_width, bit_width, signed):
        """ Create a new Fixed from its raw bits.

        :param bits: the bits of the fixed-point number; normalized to
            ``(bit_width, signed)`` with ``Const.normalize``
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        """
        retval = Fixed(0, fract_width, bit_width, signed)
        retval.bits = Const.normalize(bits, (bit_width, signed))
        return retval

    def __init__(self, value, fract_width, bit_width, signed):
        """ Create a new Fixed.

        :param value: the value of the fixed-point number. A ``Fixed`` is
            rescaled to ``fract_width`` by shifting its bits; an ``int`` is
            shifted left by ``fract_width``; anything else (e.g. a float) is
            scaled by ``2 ** fract_width`` and rounded down with ``floor``.
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        """
        # floor was previously used without being imported (NameError for
        # float values)
        from math import floor

        assert fract_width >= 0
        assert bit_width > 0
        if isinstance(value, Fixed):
            if fract_width < value.fract_width:
                bits = value.bits >> (value.fract_width - fract_width)
            else:
                bits = value.bits << (fract_width - value.fract_width)
        elif isinstance(value, int):
            bits = value << fract_width
        else:
            bits = floor(value * 2 ** fract_width)
        self.bits = Const.normalize(bits, (bit_width, signed))
        self.fract_width = fract_width
        self.bit_width = bit_width
        self.signed = signed

    def __repr__(self):
        """ Get representation that round-trips through ``from_bits``."""
        return (f"Fixed.from_bits({self.bits}, {self.fract_width}, "
                f"{self.bit_width}, {self.signed})")

    def __trunc__(self):
        """ Truncate to integer (round towards zero)."""
        if self.bits < 0:
            return self.__ceil__()
        return self.__floor__()

    def __int__(self):
        """ Truncate to integer."""
        return self.__trunc__()

    def __float__(self):
        """ Convert to float."""
        return self.bits * 2 ** -self.fract_width

    def __floor__(self):
        """ Floor to integer."""
        return self.bits >> self.fract_width

    def __ceil__(self):
        """ Ceil to integer."""
        return -((-self.bits) >> self.fract_width)

    def __neg__(self):
        """ Negate."""
        return self.from_bits(-self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def __pos__(self):
        """ Unary Positive."""
        return self

    def __abs__(self):
        """ Absolute Value."""
        return self.from_bits(abs(self.bits), self.fract_width,
                              self.bit_width, self.signed)

    def __invert__(self):
        """ Inverse."""
        return self.from_bits(~self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def _binary_op(self, rhs, operation, full=False):
        """ Handle binary arithmetic operators.

        Both operands are aligned to the larger fract_width; the result's
        integer width is the larger of the two integer widths.

        :param rhs: an ``int`` or a ``Fixed`` with matching signedness
        :param operation: called as ``operation(lhs_bits, rhs_bits,
            fract_width)``, or with ``bit_width, signed`` appended when
            ``full`` is true
        :param full: if true, return ``operation``'s result directly instead
            of wrapping it in a new Fixed
        :returns: the result, or NotImplemented for unsupported rhs types
        :raises TypeError: if rhs is a Fixed with different signedness
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                # previously *returned* the exception object by mistake
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = max(self.bit_width - self.fract_width,
                            rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = max(self.fract_width, rhs_fract_width)
        rhs_bits <<= fract_width - rhs_fract_width
        lhs_bits = self.bits << fract_width - self.fract_width
        bit_width = int_width + fract_width
        if full:
            return operation(lhs_bits, rhs_bits,
                             fract_width, bit_width, self.signed)
        bits = operation(lhs_bits, rhs_bits,
                         fract_width)
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __add__(self, rhs):
        """ Addition."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs + rhs)

    def __radd__(self, lhs):
        """ Reverse Addition."""
        return self.__add__(lhs)

    def __sub__(self, rhs):
        """ Subtraction."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs - rhs)

    def __rsub__(self, lhs):
        """ Reverse Subtraction."""
        # note swapped argument and parameter order
        return self._binary_op(lhs, lambda rhs, lhs, fract_width: lhs - rhs)

    def __and__(self, rhs):
        """ Bitwise And."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs & rhs)

    def __rand__(self, lhs):
        """ Reverse Bitwise And."""
        return self.__and__(lhs)

    def __or__(self, rhs):
        """ Bitwise Or."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs | rhs)

    def __ror__(self, lhs):
        """ Reverse Bitwise Or."""
        return self.__or__(lhs)

    def __xor__(self, rhs):
        """ Bitwise Xor."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs ^ rhs)

    def __rxor__(self, lhs):
        """ Reverse Bitwise Xor."""
        return self.__xor__(lhs)

    def __mul__(self, rhs):
        """ Multiplication.

        The result keeps every bit: fractional widths add, and for a Fixed
        rhs the integer widths add as well.

        :raises TypeError: if rhs is a Fixed with different signedness
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                # previously *returned* the exception object by mistake
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = (self.bit_width - self.fract_width
                         + rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = self.fract_width + rhs_fract_width
        bit_width = int_width + fract_width
        bits = self.bits * rhs_bits
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __rmul__(self, lhs):
        """ Reverse Multiplication (multiplication is commutative here)."""
        return self.__mul__(lhs)

    @staticmethod
    def _cmp_impl(lhs, rhs, fract_width, bit_width, signed):
        if lhs < rhs:
            return -1
        elif lhs == rhs:
            return 0
        return 1

    def cmp(self, rhs):
        """ Compare self with rhs.

        :returns int: returns -1 if self is less than rhs, 0 if they're equal,
            and 1 for greater than.
            Returns NotImplemented for unimplemented cases
        :raises TypeError: if rhs is a Fixed with different signedness
        """
        return self._binary_op(rhs, self._cmp_impl, full=True)

    def _compare(self, rhs, predicate):
        """ Apply ``predicate`` to ``self.cmp(rhs)``, propagating
        NotImplemented so Python can try the reflected operation."""
        result = self.cmp(rhs)
        if result is NotImplemented:
            return NotImplemented
        return predicate(result)

    def __lt__(self, rhs):
        """ Less Than."""
        return self._compare(rhs, lambda c: c < 0)

    def __le__(self, rhs):
        """ Less Than or Equal."""
        return self._compare(rhs, lambda c: c <= 0)

    def __eq__(self, rhs):
        """ Equal."""
        return self._compare(rhs, lambda c: c == 0)

    def __ne__(self, rhs):
        """ Not Equal."""
        return self._compare(rhs, lambda c: c != 0)

    def __gt__(self, rhs):
        """ Greater Than."""
        return self._compare(rhs, lambda c: c > 0)

    def __ge__(self, rhs):
        """ Greater Than or Equal."""
        return self._compare(rhs, lambda c: c >= 0)

    def __bool__(self):
        """ Convert to bool: True unless the value is zero."""
        return bool(self.bits)

    def __str__(self):
        """ Get text representation.

        Format: ``fixed:`` then an optional ``-``, the integer part in hex,
        ``.``, and the fractional part in hex (the fraction is padded to a
        whole number of hex digits).
        """
        # don't just use self.__float__() in order to work with numbers more
        # than 53 bits wide
        retval = "fixed:"
        bits = self.bits
        if bits < 0:
            retval += "-"
            bits = -bits
        int_part = bits >> self.fract_width
        fract_part = bits & ~(-1 << self.fract_width)
        # round up fract_width to nearest multiple of 4
        fract_width = (self.fract_width + 3) & ~3
        fract_part <<= (fract_width - self.fract_width)
        # integer division: zfill requires an int (was `/`, a TypeError)
        fract_width_in_hex_digits = fract_width // 4
        retval += f"{int_part:x}."
        retval += f"{fract_part:x}".zfill(fract_width_in_hex_digits)
        return retval
394
395
def fixed_sqrt():
    """ Fixed-point square root -- not implemented yet.

    :raises NotImplementedError: always
    """
    # FIXME: finish
    raise NotImplementedError
399
400
class FixedSqrt:
    """ Placeholder for a fixed-point square-root class (no members yet)."""
    # FIXME: finish
404
405
def fixed_rsqrt():
    """ Fixed-point reciprocal square root -- not implemented yet.

    :raises NotImplementedError: always
    """
    # FIXME: finish
    raise NotImplementedError
409
410
class FixedRSqrt:
    """ Placeholder for a fixed-point reciprocal-square-root class
    (no members yet)."""
    # FIXME: finish