# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
-from dataclasses import dataclass
+from dataclasses import dataclass, field
import math
import enum
+from fractions import Fraction
@enum.unique
def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
"""convert `value` to the nearest fixed-point number with `frac_wid`
fractional bits, rounding according to `round_dir`."""
- value = FixedPoint.cast(value)
assert isinstance(frac_wid, int) and frac_wid >= 0
assert isinstance(round_dir, RoundDir)
- # compute number of bits that should be removed from value
- del_bits = value.frac_wid - frac_wid
- if del_bits == 0:
- return value
- if del_bits < 0: # add bits
- return FixedPoint(value.bits << -del_bits,
- frac_wid)
+ if isinstance(value, Fraction):
+ numerator = value.numerator
+ denominator = value.denominator
+ else:
+ value = FixedPoint.cast(value)
+ # compute number of bits that should be removed from value
+ del_bits = value.frac_wid - frac_wid
+ if del_bits == 0:
+ return value
+ if del_bits < 0: # add bits
+ return FixedPoint(value.bits << -del_bits,
+ frac_wid)
+ numerator = value.bits
+ denominator = 1 << value.frac_wid
+ if denominator < 0:
+ numerator = -numerator
+ denominator = -denominator
+ bits, remainder = divmod(numerator << frac_wid, denominator)
if round_dir == RoundDir.DOWN:
- bits = value.bits >> del_bits
+ pass
elif round_dir == RoundDir.UP:
- bits = -((-value.bits) >> del_bits)
+ if remainder != 0:
+ bits += 1
elif round_dir == RoundDir.NEAREST_TIES_UP:
- bits = value.bits >> (del_bits - 1)
- bits += 1
- bits >>= 1
+ if remainder * 2 >= denominator:
+ bits += 1
elif round_dir == RoundDir.ERROR_IF_INEXACT:
- bits = value.bits >> del_bits
- if bits << del_bits != value.bits:
+ if remainder != 0:
raise ValueError("inexact conversion")
else:
assert False, "unimplemented round_dir"
return FixedPoint.with_frac_wid(self, frac_wid, round_dir)
def __float__(self):
- return self.bits * 2.0 ** -self.frac_wid
+ # use truediv to get correct result even when bits
+ # and frac_wid are huge
+ return float(self.bits / (1 << self.frac_wid))
+
+ def as_fraction(self):
+ return Fraction(self.bits, 1 << self.frac_wid)
def cmp(self, rhs):
"""compare self with rhs, returning a positive integer if self is
rhs = rhs.to_frac_wid(common_frac_wid)
return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)
+ def __radd__(self, lhs):
+ # symmetric
+ return self.__add__(lhs)
+
def __neg__(self):
return FixedPoint(-self.bits, self.frac_wid)
rhs = rhs.to_frac_wid(common_frac_wid)
return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)
+ def __rsub__(self, lhs):
+ # a - b == -(b - a)
+ return -self.__sub__(lhs)
+
def __mul__(self, rhs):
rhs = FixedPoint.cast(rhs)
return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)
+ def __rmul__(self, lhs):
+ # symmetric
+ return self.__mul__(lhs)
+
def __floor__(self):
return self.bits >> self.frac_wid
-def goldschmidt_div(n, d, width):
+@dataclass
+class GoldschmidtDivState:
+ n: FixedPoint
+ """numerator -- N_prime[i] in the paper's algorithm 2"""
+ d: FixedPoint
+ """denominator -- D_prime[i] in the paper's algorithm 2"""
+ f: "FixedPoint | None" = None
+ """current factor -- F_prime[i] in the paper's algorithm 2"""
+ result: "int | None" = None
+ """final result"""
+ n_shift: "int | None" = None
+ """amount the numerator needs to be left-shifted at the end of the
+ algorithm.
+ """
+
+
+class ParamsNotAccurateEnough(Exception):
+ """raised when the parameters aren't accurate enough to have goldschmidt
+ division work."""
+
+
+def _assert_accuracy(condition, msg="not accurate enough"):
+ if condition:
+ return
+ raise ParamsNotAccurateEnough(msg)
+
+
+@dataclass(frozen=True, unsafe_hash=True)
+class GoldschmidtDivParams:
+ """parameters for a Goldschmidt division algorithm.
+ Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
+ """
+ io_width: int
+ """bit-width of the input divisor and the result.
+ the input numerator is `2 * io_width`-bits wide.
+ """
+ extra_precision: int
+ """number of bits of additional precision used inside the algorithm."""
+ table_addr_bits: int
+ """the number of address bits used in the lookup-table."""
+ table_data_bits: int
+ """the number of data bits used in the lookup-table."""
+ # tuple to be immutable
+ table: "tuple[FixedPoint, ...]" = field(init=False)
+ """the lookup-table"""
+ ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False)
+ """the operations needed to perform the goldschmidt division algorithm."""
+
+ @property
+ def table_addr_count(self):
+ """number of distinct addresses in the lookup-table."""
+ # used while computing self.table, so can't just do len(self.table)
+ return 1 << self.table_addr_bits
+
+ def table_input_exact_range(self, addr):
+ """return the range of inputs as `Fraction`s used for the table entry
+ with address `addr`."""
+ assert isinstance(addr, int)
+ assert 0 <= addr < self.table_addr_count
+ assert self.io_width >= self.table_addr_bits
+ min_numerator = (1 << self.table_addr_bits) + addr
+ denominator = 1 << self.table_addr_bits
+ values_per_table_entry = 1 << (self.io_width - self.table_addr_bits)
+ max_numerator = min_numerator + values_per_table_entry
+ min_input = Fraction(min_numerator, denominator)
+ max_input = Fraction(max_numerator, denominator)
+ return min_input, max_input
+
+ def table_value_exact_range(self, addr):
+ """return the range of values as `Fraction`s used for the table entry
+ with address `addr`."""
+ min_value, max_value = self.table_input_exact_range(addr)
+ # division swaps min/max
+ return 1 / max_value, 1 / min_value
+
+ def table_exact_value(self, index):
+ min_value, max_value = self.table_value_exact_range(index)
+ # we round down
+ return min_value
+
+ def __post_init__(self):
+ # called by the autogenerated __init__
+ assert self.io_width >= 1
+ assert self.extra_precision >= 0
+ assert self.table_addr_bits >= 1
+ assert self.table_data_bits >= 1
+ table = []
+ for addr in range(1 << self.table_addr_bits):
+ table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
+ self.table_data_bits,
+ RoundDir.DOWN))
+ # we have to use object.__setattr__ since frozen=True
+ object.__setattr__(self, "table", tuple(table))
+ object.__setattr__(self, "ops", tuple(_goldschmidt_div_ops(self)))
+
+ @staticmethod
+ def get(io_width):
+ """ find efficient parameters for a goldschmidt division algorithm
+ with `params.io_width == io_width`.
+ """
+ assert isinstance(io_width, int) and io_width >= 1
+ for extra_precision in range(io_width * 2):
+ for table_addr_bits in range(3, 7 + 1):
+ table_data_bits = io_width + extra_precision
+ try:
+ return GoldschmidtDivParams(
+ io_width=io_width,
+ extra_precision=extra_precision,
+ table_addr_bits=table_addr_bits,
+ table_data_bits=table_data_bits)
+ except ParamsNotAccurateEnough:
+ pass
+ raise ValueError(f"can't find working parameters for a goldschmidt "
+ f"division algorithm with io_width={io_width}")
+
+ @property
+ def expanded_width(self):
+ """the total number of bits of precision used inside the algorithm."""
+ return self.io_width + self.extra_precision
+
+
+@enum.unique
+class GoldschmidtDivOp(enum.Enum):
+ Normalize = "n, d, n_shift = normalize(n, d)"
+ FEqTableLookup = "f = table_lookup(d)"
+ MulNByF = "n *= f"
+ MulDByF = "d *= f"
+ FEq2MinusD = "f = 2 - d"
+ CalcResult = "result = unnormalize_and_round(n)"
+
+ def run(self, params, state):
+ assert isinstance(params, GoldschmidtDivParams)
+ assert isinstance(state, GoldschmidtDivState)
+ expanded_width = params.expanded_width
+ table_addr_bits = params.table_addr_bits
+ if self == GoldschmidtDivOp.Normalize:
+ # normalize so 1 <= d < 2
+ # can easily be done with count-leading-zeros and left shift
+ while state.d < 1:
+ state.n = (state.n * 2).to_frac_wid(expanded_width)
+ state.d = (state.d * 2).to_frac_wid(expanded_width)
+
+ state.n_shift = 0
+ # normalize so 1 <= n < 2
+ while state.n >= 2:
+ state.n = (state.n * 0.5).to_frac_wid(expanded_width)
+ state.n_shift += 1
+ elif self == GoldschmidtDivOp.FEqTableLookup:
+ # compute initial f by table lookup
+ d_m_1 = state.d - 1
+ d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
+ assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
+ state.f = params.table[d_m_1.bits]
+ elif self == GoldschmidtDivOp.MulNByF:
+ assert state.f is not None
+ n = state.n * state.f
+ state.n = n.to_frac_wid(expanded_width, round_dir=RoundDir.DOWN)
+ elif self == GoldschmidtDivOp.MulDByF:
+ assert state.f is not None
+ d = state.d * state.f
+ state.d = d.to_frac_wid(expanded_width, round_dir=RoundDir.UP)
+ elif self == GoldschmidtDivOp.FEq2MinusD:
+ state.f = (2 - state.d).to_frac_wid(expanded_width)
+ elif self == GoldschmidtDivOp.CalcResult:
+ assert state.n_shift is not None
+ # scale to correct value
+ n = state.n * (1 << state.n_shift)
+
+ # avoid incorrectly rounding down
+ n = n.to_frac_wid(params.io_width, round_dir=RoundDir.UP)
+ state.result = math.floor(n)
+ else:
+ assert False, f"unimplemented GoldschmidtDivOp: {self}"
+
+
+def _goldschmidt_div_ops(params):
+ """ Goldschmidt division algorithm.
+
+ based on:
+ Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+ A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+ https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+ arguments:
+ params: GoldschmidtDivParams
+ the parameters for the algorithm
+
+ yields: GoldschmidtDivOp
+ the operations needed to perform the division.
+ """
+ assert isinstance(params, GoldschmidtDivParams)
+
+ # establish assumptions of the paper's error analysis (section 3.1):
+
+ # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
+ yield GoldschmidtDivOp.Normalize
+
+ # 2. ensure all relative errors from directed rounding are <= 1 / 4.
+ # the assumption is met by multipliers with > 4-bits precision
+ _assert_accuracy(params.expanded_width > 4)
+
+ # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
+
+ # maximum `abs(e[0])`
+ max_abs_e0 = 0
+ # maximum `d[0]`
+ max_d0 = 0
+ # `f[i] = 0` for all `i`
+ fi = 0
+ for addr in range(params.table_addr_count):
+ # `F_prime[-1] = (1 - e[0]) / B`
+ # => `e[0] = 1 - B * F_prime[-1]`
+ min_b, max_b = params.table_input_exact_range(addr)
+ f_prime_m1 = params.table[addr].as_fraction()
+ assert min_b >= 0 and f_prime_m1 >= 0, \
+ "only positive quadrant of interval multiplication implemented"
+ min_product = min_b * f_prime_m1
+ max_product = max_b * f_prime_m1
+ # negation swaps min/max
+ min_e0 = 1 - max_product
+ max_e0 = 1 - min_product
+ max_abs_e0 = max(max_abs_e0, abs(min_e0), abs(max_e0))
+
+ # `D_prime[0] = (1 + d[0]) * B * F_prime[-1]`
+ # `D_prime[0] = abs_round_err + B * F_prime[-1]`
+ # => `d[0] = abs_round_err / (B * F_prime[-1])`
+ max_abs_round_err = Fraction(1, 1 << params.expanded_width)
+ assert min_product > 0 and max_abs_round_err >= 0, \
+ "only positive quadrant of interval division implemented"
+ # division swaps divisor's min/max
+ max_d0 = max(max_d0, max_abs_round_err / min_product)
+
+ _assert_accuracy(max_abs_e0 + 3 * max_d0 / 2 + fi < Fraction(1, 2))
+
+ # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
+ # (B is the denominator)
+
+ for addr in range(params.table_addr_count):
+ f_prime_m1 = params.table[addr]
+ _assert_accuracy(0.5 <= f_prime_m1 <= 1)
+
+ yield GoldschmidtDivOp.FEqTableLookup
+
+ # we use Setting I (section 4.1 of the paper)
+
+ min_bits_of_precision = 1
+ # FIXME: calculate error and check if it's small enough
+ while min_bits_of_precision < params.io_width * 2:
+ yield GoldschmidtDivOp.MulNByF
+ yield GoldschmidtDivOp.MulDByF
+ yield GoldschmidtDivOp.FEq2MinusD
+
+ min_bits_of_precision *= 2
+
+ yield GoldschmidtDivOp.CalcResult
+
+
+def goldschmidt_div(n, d, params):
""" Goldschmidt division algorithm.
based on:
returns: int
the quotient. a `width`-bit unsigned integer.
"""
- assert isinstance(width, int) and width >= 1
- assert isinstance(d, int) and 0 < d < (1 << width)
- assert isinstance(n, int) and 0 <= n < (d << width)
-
- # FIXME: calculate best values for extra_precision, table_addr_bits, and
- # table_data_bits -- these are wrong
- extra_precision = width + 3
- table_addr_bits = 4
- table_data_bits = 8
-
- width += extra_precision
-
- table = []
- for i in range(1 << table_addr_bits):
- value = 1 / (1 + i * 2 ** -table_addr_bits)
- table.append(FixedPoint.with_frac_wid(value, table_data_bits,
- RoundDir.DOWN))
+ assert isinstance(params, GoldschmidtDivParams)
+ assert isinstance(d, int) and 0 < d < (1 << params.io_width)
+ assert isinstance(n, int) and 0 <= n < (d << params.io_width)
# this whole algorithm is done with fixed-point arithmetic where values
# have `width` fractional bits
- n = FixedPoint(n, width)
- d = FixedPoint(d, width)
+ state = GoldschmidtDivState(
+ n=FixedPoint(n, params.io_width),
+ d=FixedPoint(d, params.io_width),
+ )
- # normalize so 1 <= d < 2
- # can easily be done with count-leading-zeros and left shift
- while d < 1:
- n = (n * 2).to_frac_wid(width)
- d = (d * 2).to_frac_wid(width)
-
- n_shift = 0
- # normalize so 1 <= n < 2
- while n >= 2:
- n = (n * 0.5).to_frac_wid(width)
- n_shift += 1
-
- # compute initial f by table lookup
- f = table[(d - 1).to_frac_wid(table_addr_bits, RoundDir.DOWN).bits]
-
- min_bits_of_precision = 1
- while min_bits_of_precision < width * 2:
- # multiply both n and d by f
- n *= f
- d *= f
- n = n.to_frac_wid(width, round_dir=RoundDir.DOWN)
- d = d.to_frac_wid(width, round_dir=RoundDir.UP)
-
- # slightly less than 2 to make the computation just a bitwise not
- nearly_two = FixedPoint.with_frac_wid(2, width)
- nearly_two = FixedPoint(nearly_two.bits - 1, width)
- f = (nearly_two - d).to_frac_wid(width)
-
- min_bits_of_precision *= 2
+ for op in params.ops:
+ op.run(params, state)
- # scale to correct value
- n *= 1 << n_shift
+ assert state.result is not None
- # avoid incorrectly rounding down
- n = n.to_frac_wid(width - extra_precision, round_dir=RoundDir.UP)
- return math.floor(n)
+ return state.result