dc363c5a4bebe56df804193c36bf71fa11db76d7
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
, field
10 from fractions
import Fraction
11 from types
import FunctionType
14 from functools
import cached_property
16 from cached_property
import cached_property
18 # fix broken IDE type detection for cached_property
19 from typing
import TYPE_CHECKING
21 from functools
import cached_property
27 def cache_on_self(func
):
28 """like `functools.cached_property`, except for methods. unlike
29 `lru_cache` the cache is per-class instance rather than a global cache
32 assert isinstance(func
, FunctionType
), \
33 "non-plain methods are not supported"
35 cache_name
= func
.__name
__ + "__cache"
37 def wrapper(self
, *args
, **kwargs
):
38 # specifically access through `__dict__` to bypass frozen=True
39 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
40 if cache
is _NOT_FOUND
:
41 self
.__dict
__[cache_name
] = cache
= {}
42 key
= (args
, *kwargs
.items())
43 retval
= cache
.get(key
, _NOT_FOUND
)
44 if retval
is _NOT_FOUND
:
45 retval
= func(self
, *args
, **kwargs
)
49 wrapper
.__doc
__ = func
.__doc
__
54 class RoundDir(enum
.Enum
):
57 NEAREST_TIES_UP
= enum
.auto()
58 ERROR_IF_INEXACT
= enum
.auto()
61 @dataclass(frozen
=True)
66 def __post_init__(self
):
67 assert isinstance(self
.bits
, int)
68 assert isinstance(self
.frac_wid
, int) and self
.frac_wid
>= 0
72 """convert `value` to a fixed-point number with enough fractional
73 bits to preserve its value."""
74 if isinstance(value
, FixedPoint
):
76 if isinstance(value
, int):
77 return FixedPoint(value
, 0)
78 if isinstance(value
, str):
80 neg
= value
.startswith("-")
81 if neg
or value
.startswith("+"):
83 if value
.startswith(("0x", "0X")) and "." in value
:
93 raise ValueError("too many `.` in string")
98 if not digit
.isalnum():
99 raise ValueError("invalid hexadecimal digit")
101 bits |
= int("0x" + digit
, base
=16)
103 bits
= int(value
, base
=0)
107 return FixedPoint(bits
, frac_wid
)
109 if isinstance(value
, float):
110 n
, d
= value
.as_integer_ratio()
111 log2_d
= d
.bit_length() - 1
112 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
113 "fail with float being IEEE 754")
114 return FixedPoint(n
, log2_d
)
115 raise TypeError("can't convert type to FixedPoint")
118 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
119 """convert `value` to the nearest fixed-point number with `frac_wid`
120 fractional bits, rounding according to `round_dir`."""
121 assert isinstance(frac_wid
, int) and frac_wid
>= 0
122 assert isinstance(round_dir
, RoundDir
)
123 if isinstance(value
, Fraction
):
124 numerator
= value
.numerator
125 denominator
= value
.denominator
127 value
= FixedPoint
.cast(value
)
128 # compute number of bits that should be removed from value
129 del_bits
= value
.frac_wid
- frac_wid
132 if del_bits
< 0: # add bits
133 return FixedPoint(value
.bits
<< -del_bits
,
135 numerator
= value
.bits
136 denominator
= 1 << value
.frac_wid
138 numerator
= -numerator
139 denominator
= -denominator
140 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
141 if round_dir
== RoundDir
.DOWN
:
143 elif round_dir
== RoundDir
.UP
:
146 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
147 if remainder
* 2 >= denominator
:
149 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
151 raise ValueError("inexact conversion")
153 assert False, "unimplemented round_dir"
154 return FixedPoint(bits
, frac_wid
)
156 def to_frac_wid(self
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
157 """convert to the nearest fixed-point number with `frac_wid`
158 fractional bits, rounding according to `round_dir`."""
159 return FixedPoint
.with_frac_wid(self
, frac_wid
, round_dir
)
162 # use truediv to get correct result even when bits
163 # and frac_wid are huge
164 return float(self
.bits
/ (1 << self
.frac_wid
))
166 def as_fraction(self
):
167 return Fraction(self
.bits
, 1 << self
.frac_wid
)
170 """compare self with rhs, returning a positive integer if self is
171 greater than rhs, zero if self is equal to rhs, and a negative integer
172 if self is less than rhs."""
173 rhs
= FixedPoint
.cast(rhs
)
174 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
175 lhs
= self
.to_frac_wid(common_frac_wid
)
176 rhs
= rhs
.to_frac_wid(common_frac_wid
)
177 return lhs
.bits
- rhs
.bits
179 def __eq__(self
, rhs
):
180 return self
.cmp(rhs
) == 0
182 def __ne__(self
, rhs
):
183 return self
.cmp(rhs
) != 0
185 def __gt__(self
, rhs
):
186 return self
.cmp(rhs
) > 0
188 def __lt__(self
, rhs
):
189 return self
.cmp(rhs
) < 0
191 def __ge__(self
, rhs
):
192 return self
.cmp(rhs
) >= 0
194 def __le__(self
, rhs
):
195 return self
.cmp(rhs
) <= 0
198 """return the fractional part of `self`.
199 that is `self - math.floor(self)`.
201 fract_mask
= (1 << self
.frac_wid
) - 1
202 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
206 return "-" + str(-self
)
208 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
209 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
210 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
211 return hex(math
.floor(self
)) + "." + frac_str
214 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
216 def __add__(self
, rhs
):
217 rhs
= FixedPoint
.cast(rhs
)
218 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
219 lhs
= self
.to_frac_wid(common_frac_wid
)
220 rhs
= rhs
.to_frac_wid(common_frac_wid
)
221 return FixedPoint(lhs
.bits
+ rhs
.bits
, common_frac_wid
)
223 def __radd__(self
, lhs
):
225 return self
.__add
__(lhs
)
228 return FixedPoint(-self
.bits
, self
.frac_wid
)
230 def __sub__(self
, rhs
):
231 rhs
= FixedPoint
.cast(rhs
)
232 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
233 lhs
= self
.to_frac_wid(common_frac_wid
)
234 rhs
= rhs
.to_frac_wid(common_frac_wid
)
235 return FixedPoint(lhs
.bits
- rhs
.bits
, common_frac_wid
)
237 def __rsub__(self
, lhs
):
239 return -self
.__sub
__(lhs
)
241 def __mul__(self
, rhs
):
242 rhs
= FixedPoint
.cast(rhs
)
243 return FixedPoint(self
.bits
* rhs
.bits
, self
.frac_wid
+ rhs
.frac_wid
)
245 def __rmul__(self
, lhs
):
247 return self
.__mul
__(lhs
)
250 return self
.bits
>> self
.frac_wid
254 class GoldschmidtDivState
:
256 """original numerator"""
259 """original denominator"""
262 """numerator -- N_prime[i] in the paper's algorithm 2"""
265 """denominator -- D_prime[i] in the paper's algorithm 2"""
267 f
: "FixedPoint | None" = None
268 """current factor -- F_prime[i] in the paper's algorithm 2"""
270 quotient
: "int | None" = None
273 remainder
: "int | None" = None
274 """final remainder"""
276 n_shift
: "int | None" = None
277 """amount the numerator needs to be left-shifted at the end of the
282 class ParamsNotAccurateEnough(Exception):
283 """raised when the parameters aren't accurate enough to have goldschmidt
287 def _assert_accuracy(condition
, msg
="not accurate enough"):
290 raise ParamsNotAccurateEnough(msg
)
293 @dataclass(frozen
=True, unsafe_hash
=True)
294 class GoldschmidtDivParams
:
295 """parameters for a Goldschmidt division algorithm.
296 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
300 """bit-width of the input divisor and the result.
301 the input numerator is `2 * io_width`-bits wide.
305 """number of bits of additional precision used inside the algorithm."""
308 """the number of address bits used in the lookup-table."""
311 """the number of data bits used in the lookup-table."""
314 """the total number of iterations of the division algorithm's loop"""
316 # tuple to be immutable
317 table
: "tuple[FixedPoint, ...]" = field(init
=False)
318 """the lookup-table"""
320 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False)
321 """the operations needed to perform the goldschmidt division algorithm."""
324 def table_addr_count(self
):
325 """number of distinct addresses in the lookup-table."""
326 # used while computing self.table, so can't just do len(self.table)
327 return 1 << self
.table_addr_bits
329 def table_input_exact_range(self
, addr
):
330 """return the range of inputs as `Fraction`s used for the table entry
331 with address `addr`."""
332 assert isinstance(addr
, int)
333 assert 0 <= addr
< self
.table_addr_count
334 _assert_accuracy(self
.io_width
>= self
.table_addr_bits
)
335 min_numerator
= (1 << self
.table_addr_bits
) + addr
336 denominator
= 1 << self
.table_addr_bits
337 values_per_table_entry
= 1 << (self
.io_width
- self
.table_addr_bits
)
338 max_numerator
= min_numerator
+ values_per_table_entry
339 min_input
= Fraction(min_numerator
, denominator
)
340 max_input
= Fraction(max_numerator
, denominator
)
341 return min_input
, max_input
343 def table_value_exact_range(self
, addr
):
344 """return the range of values as `Fraction`s used for the table entry
345 with address `addr`."""
346 min_value
, max_value
= self
.table_input_exact_range(addr
)
347 # division swaps min/max
348 return 1 / max_value
, 1 / min_value
350 def table_exact_value(self
, index
):
351 min_value
, max_value
= self
.table_value_exact_range(index
)
355 def __post_init__(self
):
356 # called by the autogenerated __init__
357 assert self
.io_width
>= 1
358 assert self
.extra_precision
>= 0
359 assert self
.table_addr_bits
>= 1
360 assert self
.table_data_bits
>= 1
361 assert self
.iter_count
>= 1
363 for addr
in range(1 << self
.table_addr_bits
):
364 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
365 self
.table_data_bits
,
367 # we have to use object.__setattr__ since frozen=True
368 object.__setattr
__(self
, "table", tuple(table
))
369 object.__setattr
__(self
, "ops", tuple(_goldschmidt_div_ops(self
)))
373 """ find efficient parameters for a goldschmidt division algorithm
374 with `params.io_width == io_width`.
376 assert isinstance(io_width
, int) and io_width
>= 1
377 for extra_precision
in range(io_width
* 2 + 4):
378 for table_addr_bits
in range(1, 7 + 1):
379 table_data_bits
= io_width
+ extra_precision
380 for iter_count
in range(1, 2 * io_width
.bit_length()):
382 return GoldschmidtDivParams(
384 extra_precision
=extra_precision
,
385 table_addr_bits
=table_addr_bits
,
386 table_data_bits
=table_data_bits
,
387 iter_count
=iter_count
)
388 except ParamsNotAccurateEnough
:
390 raise ValueError(f
"can't find working parameters for a goldschmidt "
391 f
"division algorithm with io_width={io_width}")
394 def expanded_width(self
):
395 """the total number of bits of precision used inside the algorithm."""
396 return self
.io_width
+ self
.extra_precision
399 def max_neps(self
, i
):
400 """maximum value of `neps[i]`.
401 `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
403 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
404 return Fraction(1, 1 << self
.expanded_width
)
407 def max_deps(self
, i
):
408 """maximum value of `deps[i]`.
409 `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
411 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
412 return Fraction(1, 1 << self
.expanded_width
)
415 def max_feps(self
, i
):
416 """maximum value of `feps[i]`.
417 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
419 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
420 # zero, because the computation of `F_prime[i]` in
421 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
426 """minimum and maximum values of `e[0]`
427 (the relative error in `F_prime[-1]`)
431 for addr
in range(self
.table_addr_count
):
432 # `F_prime[-1] = (1 - e[0]) / B`
433 # => `e[0] = 1 - B * F_prime[-1]`
434 min_b
, max_b
= self
.table_input_exact_range(addr
)
435 f_prime_m1
= self
.table
[addr
].as_fraction()
436 assert min_b
>= 0 and f_prime_m1
>= 0, \
437 "only positive quadrant of interval multiplication implemented"
438 min_product
= min_b
* f_prime_m1
439 max_product
= max_b
* f_prime_m1
440 # negation swaps min/max
441 cur_min_e0
= 1 - max_product
442 cur_max_e0
= 1 - min_product
443 min_e0
= min(min_e0
, cur_min_e0
)
444 max_e0
= max(max_e0
, cur_max_e0
)
445 return min_e0
, max_e0
449 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
451 min_e0
, max_e0
= self
.e0_range
456 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
458 min_e0
, max_e0
= self
.e0_range
462 def max_abs_e0(self
):
463 """maximum value of `abs(e[0])`."""
464 return max(abs(self
.min_e0
), abs(self
.max_e0
))
467 def min_abs_e0(self
):
468 """minimum value of `abs(e[0])`."""
473 """maximum value of `n[i]` (the relative error in `N_prime[i]`
474 relative to the previous iteration)
476 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
479 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
480 # `n[0] <= 2 * neps[0] / (1 - e[0])`
482 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
483 "only one quadrant of interval division implemented"
484 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
487 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
488 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
489 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
490 "only one quadrant of interval multiplication implemented"
491 prod
= (1 - self
.max_f(0)) * min_mpd
492 assert self
.max_neps(1) >= 0 and prod
> 0, \
493 "only one quadrant of interval division implemented"
494 retval
= self
.max_neps(1) / prod
497 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
498 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
499 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
500 "only one quadrant of interval division implemented"
501 retval
= self
.max_neps(i
) / min_mpd
503 # we need Fraction to avoid using float by accident
504 # -- it also hints to the IDE to give the correct type
505 return Fraction(retval
)
509 """maximum value of `d[i]` (the relative error in `D_prime[i]`
510 relative to the previous iteration)
512 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
515 # `d[0] = deps[0] / (1 - e[0])`
517 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
518 "only one quadrant of interval division implemented"
519 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
522 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
523 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
524 "only one quadrant of interval multiplication implemented"
525 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
526 assert self
.max_deps(1) >= 0 and divisor
> 0, \
527 "only one quadrant of interval division implemented"
528 retval
= self
.max_deps(1) / divisor
531 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
532 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
533 "only one quadrant of interval division implemented"
534 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
536 # we need Fraction to avoid using float by accident
537 # -- it also hints to the IDE to give the correct type
538 return Fraction(retval
)
542 """maximum value of `f[i]` (the relative error in `F_prime[i]`
543 relative to the previous iteration)
545 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
548 # `f[0] = feps[0] / (1 - delta[0])`
550 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
551 "only one quadrant of interval division implemented"
552 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
556 retval
= self
.max_feps(1)
559 # `f[i] <= max_feps[i]`
560 retval
= self
.max_feps(i
)
562 # we need Fraction to avoid using float by accident
563 # -- it also hints to the IDE to give the correct type
564 return Fraction(retval
)
567 def max_delta(self
, i
):
568 """ maximum value of `delta[i]`.
569 `delta[i]` is defined in Definition 4 of paper.
571 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
573 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
574 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
576 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
577 prev_max_delta
= self
.max_delta(i
- 1)
578 assert prev_max_delta
>= 0
579 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
581 # we need Fraction to avoid using float by accident
582 # -- it also hints to the IDE to give the correct type
583 return Fraction(retval
)
587 """ maximum value of `pi[i]`.
588 `pi[i]` is defined right below Theorem 5 of paper.
590 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
591 # `pi[i] = 1 - (1 - n[i]) * prod`
592 # where `prod` is the product of,
593 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
594 min_prod
= Fraction(0)
596 max_n_j
= self
.max_n(j
)
597 max_d_j
= self
.max_d(j
)
598 assert max_n_j
<= 1 and max_d_j
> -1, \
599 "only one quadrant of interval division implemented"
600 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
601 max_n_i
= self
.max_n(i
)
602 assert max_n_i
<= 1 and min_prod
>= 0, \
603 "only one quadrant of interval multiplication implemented"
604 return 1 - (1 - max_n_i
) * min_prod
607 def max_n_shift(self
):
608 """ maximum value of `state.n_shift`.
610 # input numerator is `2*io_width`-bits
611 max_n
= (1 << (self
.io_width
* 2)) - 1
613 # normalize so 1 <= n < 2
621 class GoldschmidtDivOp(enum
.Enum
):
622 Normalize
= "n, d, n_shift = normalize(n, d)"
623 FEqTableLookup
= "f = table_lookup(d)"
626 FEq2MinusD
= "f = 2 - d"
627 CalcResult
= "result = unnormalize_and_round(n)"
629 def run(self
, params
, state
):
630 assert isinstance(params
, GoldschmidtDivParams
)
631 assert isinstance(state
, GoldschmidtDivState
)
632 expanded_width
= params
.expanded_width
633 table_addr_bits
= params
.table_addr_bits
634 if self
== GoldschmidtDivOp
.Normalize
:
635 # normalize so 1 <= d < 2
636 # can easily be done with count-leading-zeros and left shift
638 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
639 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
642 # normalize so 1 <= n < 2
644 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
)
646 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
647 # compute initial f by table lookup
649 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
650 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
651 state
.f
= params
.table
[d_m_1
.bits
]
652 elif self
== GoldschmidtDivOp
.MulNByF
:
653 assert state
.f
is not None
654 n
= state
.n
* state
.f
655 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
656 elif self
== GoldschmidtDivOp
.MulDByF
:
657 assert state
.f
is not None
658 d
= state
.d
* state
.f
659 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
660 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
661 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
662 elif self
== GoldschmidtDivOp
.CalcResult
:
663 assert state
.n_shift
is not None
664 # scale to correct value
665 n
= state
.n
* (1 << state
.n_shift
)
667 state
.quotient
= math
.floor(n
)
668 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
669 if state
.remainder
>= state
.orig_d
:
671 state
.remainder
-= state
.orig_d
673 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
676 def _goldschmidt_div_ops(params
):
677 """ Goldschmidt division algorithm.
680 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
681 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
682 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
685 params: GoldschmidtDivParams
686 the parameters for the algorithm
688 yields: GoldschmidtDivOp
689 the operations needed to perform the division.
691 assert isinstance(params
, GoldschmidtDivParams
)
693 # establish assumptions of the paper's error analysis (section 3.1):
695 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
696 yield GoldschmidtDivOp
.Normalize
698 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
699 # the assumption is met by multipliers with > 4-bits precision
700 _assert_accuracy(params
.expanded_width
> 4)
702 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
703 _assert_accuracy(params
.max_abs_e0
+ 3 * params
.max_d(0) / 2
704 + params
.max_f(0) < Fraction(1, 2))
706 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
707 # (B is the denominator)
709 for addr
in range(params
.table_addr_count
):
710 f_prime_m1
= params
.table
[addr
]
711 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
713 yield GoldschmidtDivOp
.FEqTableLookup
715 # we use Setting I (section 4.1 of the paper):
716 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
718 for i
in range(params
.iter_count
):
719 _assert_accuracy(params
.max_f(i
) == 0)
720 n_hat
= max(n_hat
, params
.max_n(i
), params
.max_d(i
))
721 yield GoldschmidtDivOp
.MulNByF
722 if i
!= params
.iter_count
- 1:
723 yield GoldschmidtDivOp
.MulDByF
724 yield GoldschmidtDivOp
.FEq2MinusD
726 # relative approximation error `p(N_prime[i])`:
727 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
728 # `0 <= p(N_prime[i])`
729 # `p(N_prime[i]) <= (2 * i) * n_hat \`
730 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
731 i
= params
.iter_count
- 1 # last used `i`
732 max_rel_error
= (2 * i
) * n_hat
+ \
733 (params
.max_abs_e0
+ 3 * n_hat
/ 2) ** (2 ** i
)
735 min_a_over_b
= Fraction(1, 2)
736 max_a_over_b
= Fraction(2)
737 max_allowed_abs_error
= max_a_over_b
/ (1 << params
.max_n_shift
)
738 max_allowed_rel_error
= max_allowed_abs_error
/ min_a_over_b
740 _assert_accuracy(max_rel_error
< max_allowed_rel_error
)
742 yield GoldschmidtDivOp
.CalcResult
745 def goldschmidt_div(n
, d
, params
):
746 """ Goldschmidt division algorithm.
749 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
750 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
751 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
755 numerator. a `2*width`-bit unsigned integer.
756 must be less than `d << width`, otherwise the quotient wouldn't
759 denominator. a `width`-bit unsigned integer. must not be zero.
761 the bit-width of the inputs/outputs. must be a positive integer.
763 returns: tuple[int, int]
764 the quotient and remainder. a tuple of two `width`-bit unsigned
767 assert isinstance(params
, GoldschmidtDivParams
)
768 assert isinstance(d
, int) and 0 < d
< (1 << params
.io_width
)
769 assert isinstance(n
, int) and 0 <= n
< (d
<< params
.io_width
)
771 # this whole algorithm is done with fixed-point arithmetic where values
772 # have `width` fractional bits
774 state
= GoldschmidtDivState(
777 n
=FixedPoint(n
, params
.io_width
),
778 d
=FixedPoint(d
, params
.io_width
),
781 for op
in params
.ops
:
782 op
.run(params
, state
)
784 assert state
.quotient
is not None
785 assert state
.remainder
is not None
787 return state
.quotient
, state
.remainder