# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
+from collections import defaultdict
from dataclasses import dataclass, field, fields, replace
import logging
import math
from fractions import Fraction
from types import FunctionType
from functools import lru_cache
+from nmigen.hdl.ast import Signal, unsigned, Mux, signed
+from nmigen.hdl.dsl import Module, Elaboratable
+from nmigen.hdl.mem import Memory
+from nmutil.clz import CLZ
try:
from functools import cached_property
frac_wid: int
def __post_init__(self):
+ # called by the autogenerated __init__
assert isinstance(self.bits, int)
assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
"""the total number of bits of precision used inside the algorithm."""
return self.io_width + self.extra_precision
+ @property
+ def n_d_f_int_wid(self):
+ """the number of bits in the integer part of `state.n`, `state.d`, and
+ `state.f` during the main iteration loop.
+ """
+ return 2
+
+ @property
+ def n_d_f_total_wid(self):
+ """the total number of bits (both integer and fraction bits) in
+ `state.n`, `state.d`, and `state.f` during the main iteration loop.
+ """
+ return self.n_d_f_int_wid + self.expanded_width
+
@cache_on_self
def max_neps(self, i):
"""maximum value of `neps[i]`.
return cached_new(params)
+def clz(v, wid):
+ assert isinstance(wid, int)
+ assert isinstance(v, int) and 0 <= v < (1 << wid)
+ return (1 << wid).bit_length() - v.bit_length()
+
+
@enum.unique
class GoldschmidtDivOp(enum.Enum):
Normalize = "n, d, n_shift = normalize(n, d)"
else:
assert False, f"unimplemented GoldschmidtDivOp: {self}"
+ def gen_hdl(self, params, state, sync_rom):
+ # FIXME: finish getting hdl/simulation to work
+ """generate the hdl for this operation.
+
+ arguments:
+ params: GoldschmidtDivParams
+ the goldschmidt division parameters.
+ state: GoldschmidtDivHDLState
+ the input/output state
+ sync_rom: bool
+ true if the rom should be read synchronously rather than
+ combinatorially, incurring an extra clock cycle of latency.
+ """
+ assert isinstance(params, GoldschmidtDivParams)
+ assert isinstance(state, GoldschmidtDivHDLState)
+ m = state.m
+ expanded_width = params.expanded_width
+ table_addr_bits = params.table_addr_bits
+ if self == GoldschmidtDivOp.Normalize:
+ # normalize so 1 <= d < 2
+ assert state.d.width == params.io_width
+ assert state.n.width == 2 * params.io_width
+ d_leading_zeros = CLZ(params.io_width)
+ m.submodules.d_leading_zeros = d_leading_zeros
+ m.d.comb += d_leading_zeros.sig_in.eq(state.d)
+ d_shift_out = Signal.like(state.d)
+ m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
+ state.d = Signal(params.n_d_f_total_wid)
+ m.d.comb += state.d.eq(d_shift_out << (params.extra_precision
+ + params.n_d_f_int_wid))
+
+ # normalize so 1 <= n < 2
+ n_leading_zeros = CLZ(2 * params.io_width)
+ m.submodules.n_leading_zeros = n_leading_zeros
+ m.d.comb += n_leading_zeros.sig_in.eq(state.n)
+ n_shift_s_v = (params.io_width + d_leading_zeros.lz
+ - n_leading_zeros.lz)
+ n_shift_s = Signal.like(n_shift_s_v)
+ state.n_shift = Signal(d_leading_zeros.lz.width)
+ m.d.comb += [
+ n_shift_s.eq(n_shift_s_v),
+ state.n_shift.eq(Mux(n_shift_s < 0, 0, n_shift_s)),
+ ]
+ n = Signal(params.n_d_f_total_wid)
+ shifted_n = state.n << state.n_shift
+ fixed_shift = params.expanded_width - state.n.width
+ m.d.comb += n.eq(shifted_n << fixed_shift)
+ state.n = n
+ elif self == GoldschmidtDivOp.FEqTableLookup:
+ assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+ # compute initial f by table lookup
+
+ # extra bit for table entries == 1.0
+ table_width = 1 + params.table_data_bits
+ table = Memory(width=table_width, depth=len(params.table),
+ init=[i.bits for i in params.table])
+ addr = state.d[:-params.n_d_f_int_wid][-table_addr_bits:]
+ if sync_rom:
+ table_read = table.read_port()
+ m.d.comb += table_read.addr.eq(addr)
+ state.insert_pipeline_register()
+ else:
+ table_read = table.read_port(domain="comb")
+ m.d.comb += table_read.addr.eq(addr)
+ m.submodules.table_read = table_read
+ state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
+ data_shift = (table_width - params.table_data_bits
+ + params.expanded_width)
+ m.d.comb += state.f.eq(table_read.data << data_shift)
+ elif self == GoldschmidtDivOp.MulNByF:
+ assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+ assert state.f is not None
+ assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+ n = Signal.like(state.n)
+ m.d.comb += n.eq((state.n * state.f) >> params.expanded_width)
+ state.n = n
+ elif self == GoldschmidtDivOp.MulDByF:
+ assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+ assert state.f is not None
+ assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+ d = Signal.like(state.d)
+ m.d.comb += d.eq((state.d * state.f) >> params.expanded_width)
+ state.d = d
+ elif self == GoldschmidtDivOp.FEq2MinusD:
+ assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+ f = Signal.like(state.d)
+ m.d.comb += f.eq((2 << params.expanded_width) - state.d)
+ state.f = f
+ elif self == GoldschmidtDivOp.CalcResult:
+ assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+ assert state.n_shift is not None
+ # scale to correct value
+ n = state.n * (1 << state.n_shift)
+ q_approx = Signal(params.io_width)
+ # extra bit for if it's bigger than orig_d
+ r_approx = Signal(params.io_width + 1)
+ adjusted_r = Signal(signed(1 + params.io_width))
+ m.d.comb += [
+ q_approx.eq((state.n << state.n_shift)
+ >> params.expanded_width),
+ r_approx.eq(state.orig_n - q_approx * state.orig_d),
+ adjusted_r.eq(r_approx - state.orig_d),
+ ]
+ state.quotient = Signal(params.io_width)
+ state.remainder = Signal(params.io_width)
+
+ with m.If(adjusted_r >= 0):
+ m.d.comb += [
+ state.quotient.eq(q_approx + 1),
+ state.remainder.eq(adjusted_r),
+ ]
+ with m.Else():
+ m.d.comb += [
+ state.quotient.eq(q_approx),
+ state.remainder.eq(r_approx),
+ ]
+ else:
+ assert False, f"unimplemented GoldschmidtDivOp: {self}"
+
@dataclass
class GoldschmidtDivState:
return state.quotient, state.remainder
+@dataclass(eq=False)
+class GoldschmidtDivHDLState:
+ m: Module
+ """The HDL Module"""
+
+ orig_n: Signal
+ """original numerator"""
+
+ orig_d: Signal
+ """original denominator"""
+
+ n: Signal
+ """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+ d: Signal
+ """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+ f: "Signal | None" = None
+ """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+ quotient: "Signal | None" = None
+ """final quotient"""
+
+ remainder: "Signal | None" = None
+ """final remainder"""
+
+ n_shift: "Signal | None" = None
+ """amount the numerator needs to be left-shifted at the end of the
+ algorithm.
+ """
+
+ old_signals: "defaultdict[str, list[Signal]]" = field(repr=False,
+ init=False)
+
+ __signal_name_prefix: "str" = field(default="state_", repr=False,
+ init=False)
+
+ def __post_init__(self):
+ # called by the autogenerated __init__
+ self.old_signals = defaultdict(list)
+
+ def __setattr__(self, name, value):
+ assert isinstance(name, str)
+ if name.startswith("_"):
+ return super().__setattr__(name, value)
+ try:
+ old_signals = self.old_signals[name]
+ except AttributeError:
+ # haven't yet finished __post_init__
+ return super().__setattr__(name, value)
+ assert name != "m" and name != "old_signals", f"can't write to {name}"
+ assert isinstance(value, Signal)
+ value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
+ old_signal = getattr(self, name, None)
+ if old_signal is not None:
+ assert isinstance(old_signal, Signal)
+ old_signals.append(old_signal)
+ return super().__setattr__(name, value)
+
+ def insert_pipeline_register(self):
+ old_prefix = self.__signal_name_prefix
+ try:
+ for field in fields(GoldschmidtDivHDLState):
+ if field.name.startswith("_") or field.name == "m":
+ continue
+ old_sig = getattr(self, field.name, None)
+ if old_sig is None:
+ continue
+ assert isinstance(old_sig, Signal)
+ new_sig = Signal.like(old_sig)
+ setattr(self, field.name, new_sig)
+ self.m.d.sync += new_sig.eq(old_sig)
+ finally:
+ self.__signal_name_prefix = old_prefix
+
+
+class GoldschmidtDivHDL(Elaboratable):
+ # FIXME: finish getting hdl/simulation to work
+ """ Goldschmidt division algorithm.
+
+ based on:
+ Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+ A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+ https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+ attributes:
+ params: GoldschmidtDivParams
+ the goldschmidt division algorithm parameters.
+ pipe_reg_indexes: list[int]
+ the operation indexes where pipeline registers should be inserted.
+ duplicate values mean multiple registers should be inserted for
+ that operation index -- this is useful to allow yosys to spread a
+ multiplication across those multiple pipeline stages.
+ sync_rom: bool
+ true if the rom should be read synchronously rather than
+ combinatorially, incurring an extra clock cycle of latency.
+ n: Signal(unsigned(2 * params.io_width))
+ input numerator. a `2 * params.io_width`-bit unsigned integer.
+ must be less than `d << params.io_width`, otherwise the quotient
+ wouldn't fit in `params.io_width` bits.
+ d: Signal(unsigned(params.io_width))
+ input denominator. a `params.io_width`-bit unsigned integer.
+ must not be zero.
+ q: Signal(unsigned(params.io_width))
+ output quotient. only valid when `n < (d << params.io_width)`.
+ r: Signal(unsigned(params.io_width))
+ output remainder. only valid when `n < (d << params.io_width)`.
+ """
+
+ @property
+ def total_pipeline_registers(self):
+ """the total number of pipeline registers"""
+ return len(self.pipe_reg_indexes) + self.sync_rom
+
+ def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
+ assert isinstance(params, GoldschmidtDivParams)
+ assert isinstance(sync_rom, bool)
+ self.params = params
+ self.pipe_reg_indexes = sorted(int(i) for i in pipe_reg_indexes)
+ self.sync_rom = sync_rom
+ self.n = Signal(unsigned(2 * params.io_width))
+ self.d = Signal(unsigned(params.io_width))
+ self.q = Signal(unsigned(params.io_width))
+ self.r = Signal(unsigned(params.io_width))
+
+ def elaborate(self, platform):
+ m = Module()
+ state = GoldschmidtDivHDLState(
+ m=m,
+ orig_n=self.n,
+ orig_d=self.d,
+ n=self.n,
+ d=self.d)
+
+ # copy and reverse
+ pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
+
+ for op_index, op in enumerate(self.params.ops):
+ while len(pipe_reg_indexes) > 0 \
+ and pipe_reg_indexes[-1] <= op_index:
+ pipe_reg_indexes.pop()
+ state.insert_pipeline_register()
+ op.gen_hdl(self.params, state, self.sync_rom)
+
+ while len(pipe_reg_indexes) > 0:
+ pipe_reg_indexes.pop()
+ state.insert_pipeline_register()
+
+ m.d.comb += self.q.eq(state.quotient)
+ m.d.comb += self.r.eq(state.remainder)
+ return m
+
+
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
import math
import unittest
from nmutil.formaltest import FHDLTestCase
+from nmutil.sim_util import do_sim
+from nmigen.sim import Tick, Delay
+from nmigen.hdl.ast import Signal
+from nmigen.hdl.dsl import Module
from soc.fu.div.experiment.goldschmidt_div_sqrt import (
- GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div,
- FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
+ GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough,
+ goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
class TestFixedPoint(FHDLTestCase):
with self.subTest(q=hex(q), r=hex(r)):
self.assertEqual((q, r), (expected_q, expected_r))
+ @unittest.skip("hdl/simulation currently broken")
+ def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
+ sync_rom=False):
+ # FIXME: finish getting hdl/simulation to work
+ assert isinstance(io_width, int)
+ params = GoldschmidtDivParams.get(io_width)
+ m = Module()
+ dut = GoldschmidtDivHDL(params, pipe_reg_indexes=pipe_reg_indexes,
+ sync_rom=sync_rom)
+ m.submodules.dut = dut
+ # make sync domain get added
+ m.d.sync += Signal().eq(0)
+
+ def iter_cases():
+ if cases is not None:
+ yield from cases
+ return
+ for d in range(1, 1 << io_width):
+ for n in range(d << io_width):
+ yield (n, d)
+
+ def inputs_proc():
+ yield Tick()
+ for n, d in iter_cases():
+ yield dut.n.eq(n)
+ yield dut.d.eq(d)
+ yield Tick()
+
+ def check_outputs():
+ yield Tick()
+ for _ in range(dut.total_pipeline_registers):
+ yield Tick()
+ for n, d in iter_cases():
+ yield Delay(0.1e-6)
+ expected_q, expected_r = divmod(n, d)
+ with self.subTest(n=hex(n), d=hex(d),
+ expected_q=hex(expected_q),
+ expected_r=hex(expected_r)):
+ q = yield dut.q
+ r = yield dut.r
+ with self.subTest(q=hex(q), r=hex(r)):
+ self.assertEqual((q, r), (expected_q, expected_r))
+ yield Tick()
+
+ with self.subTest(params=str(params)):
+ with do_sim(self, m, (dut.n, dut.d, dut.q, dut.r)) as sim:
+ sim.add_clock(1e-6)
+ sim.add_process(inputs_proc)
+ sim.add_process(check_outputs)
+ sim.run()
+
def test_1_through_4(self):
for io_width in range(1, 4 + 1):
with self.subTest(io_width=io_width):
def test_6(self):
self.tst(6)
+ def test_sim_5(self):
+ self.tst_sim(5)
+
def tst_params(self, io_width):
assert isinstance(io_width, int)
params = GoldschmidtDivParams.get(io_width)