From 383ca8957148922c68e4ea5e0c870318fb4e9ab0 Mon Sep 17 00:00:00 2001
From: Jacob Lifshay
Date: Thu, 28 Apr 2022 02:19:01 -0700
Subject: [PATCH] add WIP HDL version of goldschmidt division -- it's currently broken

---
 .../fu/div/experiment/goldschmidt_div_sqrt.py | 310 ++++++++++++++++++
 .../test/test_goldschmidt_div_sqrt.py         |  62 +++-
 2 files changed, 370 insertions(+), 2 deletions(-)

diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
index c3837e9a..3801b200 100644
--- a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
+++ b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
@@ -4,6 +4,7 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+from collections import defaultdict
 from dataclasses import dataclass, field, fields, replace
 import logging
 import math
@@ -11,6 +12,10 @@ import enum
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
+from nmigen.hdl.ast import Signal, unsigned, Mux, signed
+from nmigen.hdl.dsl import Module, Elaboratable
+from nmigen.hdl.mem import Memory
+from nmutil.clz import CLZ
 
 try:
     from functools import cached_property
@@ -66,6 +71,7 @@ class FixedPoint:
     frac_wid: int
 
     def __post_init__(self):
+        # called by the autogenerated __init__
         assert isinstance(self.bits, int)
         assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
 
@@ -463,6 +469,20 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
         """the total number of bits of precision used inside the algorithm."""
         return self.io_width + self.extra_precision
 
+    @property
+    def n_d_f_int_wid(self):
+        """the number of bits in the integer part of `state.n`, `state.d`, and
+        `state.f` during the main iteration loop.
+        """
+        return 2
+
+    @property
+    def n_d_f_total_wid(self):
+        """the total number of bits (both integer and fraction bits) in
+        `state.n`, `state.d`, and `state.f` during the main iteration loop.
+        """
+        return self.n_d_f_int_wid + self.expanded_width
+
     @cache_on_self
     def max_neps(self, i):
         """maximum value of `neps[i]`.
@@ -920,6 +940,13 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
         return cached_new(params)
 
 
+def clz(v, wid):
+    """count the leading zero bits of `v` as a `wid`-bit unsigned integer."""
+    assert isinstance(wid, int)
+    assert isinstance(v, int) and 0 <= v < (1 << wid)
+    return wid - v.bit_length()
+
+
 @enum.unique
 class GoldschmidtDivOp(enum.Enum):
     Normalize = "n, d, n_shift = normalize(n, d)"
@@ -975,6 +1002,125 @@ class GoldschmidtDivOp(enum.Enum):
         else:
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
+    def gen_hdl(self, params, state, sync_rom):
+        # FIXME: finish getting hdl/simulation to work
+        """generate the hdl for this operation.
+
+        arguments:
+        params: GoldschmidtDivParams
+            the goldschmidt division parameters.
+        state: GoldschmidtDivHDLState
+            the input/output state
+        sync_rom: bool
+            true if the rom should be read synchronously rather than
+            combinatorially, incurring an extra clock cycle of latency.
+        """
+        assert isinstance(params, GoldschmidtDivParams)
+        assert isinstance(state, GoldschmidtDivHDLState)
+        m = state.m
+        expanded_width = params.expanded_width
+        table_addr_bits = params.table_addr_bits
+        if self == GoldschmidtDivOp.Normalize:
+            # normalize so 1 <= d < 2
+            assert state.d.width == params.io_width
+            assert state.n.width == 2 * params.io_width
+            d_leading_zeros = CLZ(params.io_width)
+            m.submodules.d_leading_zeros = d_leading_zeros
+            m.d.comb += d_leading_zeros.sig_in.eq(state.d)
+            d_shift_out = Signal.like(state.d)
+            m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
+            state.d = Signal(params.n_d_f_total_wid)
+            m.d.comb += state.d.eq(d_shift_out << (params.extra_precision
+                                                   + params.n_d_f_int_wid))
+
+            # normalize so 1 <= n < 2
+            n_leading_zeros = CLZ(2 * params.io_width)
+            m.submodules.n_leading_zeros = n_leading_zeros
+            m.d.comb += n_leading_zeros.sig_in.eq(state.n)
+            n_shift_s_v = (params.io_width + d_leading_zeros.lz
+                           - n_leading_zeros.lz)
+            n_shift_s = Signal.like(n_shift_s_v)
+            state.n_shift = Signal(d_leading_zeros.lz.width)
+            m.d.comb += [
+                n_shift_s.eq(n_shift_s_v),
+                state.n_shift.eq(Mux(n_shift_s < 0, 0, n_shift_s)),
+            ]
+            n = Signal(params.n_d_f_total_wid)
+            shifted_n = state.n << state.n_shift
+            fixed_shift = params.expanded_width - state.n.width
+            m.d.comb += n.eq(shifted_n << fixed_shift)
+            state.n = n
+        elif self == GoldschmidtDivOp.FEqTableLookup:
+            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+            # compute initial f by table lookup
+
+            # extra bit for table entries == 1.0
+            table_width = 1 + params.table_data_bits
+            table = Memory(width=table_width, depth=len(params.table),
+                           init=[i.bits for i in params.table])
+            addr = state.d[:-params.n_d_f_int_wid][-table_addr_bits:]
+            if sync_rom:
+                table_read = table.read_port()
+                m.d.comb += table_read.addr.eq(addr)
+                state.insert_pipeline_register()
+            else:
+                table_read = table.read_port(domain="comb")
+                m.d.comb += table_read.addr.eq(addr)
+            m.submodules.table_read = table_read
+            state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
+            data_shift = (table_width - params.table_data_bits
+                          + params.expanded_width)
+            m.d.comb += state.f.eq(table_read.data << data_shift)
+        elif self == GoldschmidtDivOp.MulNByF:
+            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+            assert state.f is not None
+            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+            n = Signal.like(state.n)
+            m.d.comb += n.eq((state.n * state.f) >> params.expanded_width)
+            state.n = n
+        elif self == GoldschmidtDivOp.MulDByF:
+            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+            assert state.f is not None
+            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+            d = Signal.like(state.d)
+            m.d.comb += d.eq((state.d * state.f) >> params.expanded_width)
+            state.d = d
+        elif self == GoldschmidtDivOp.FEq2MinusD:
+            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+            f = Signal.like(state.d)
+            m.d.comb += f.eq((2 << params.expanded_width) - state.d)
+            state.f = f
+        elif self == GoldschmidtDivOp.CalcResult:
+            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+            assert state.n_shift is not None
+            # scale to correct value
+            n = state.n * (1 << state.n_shift)
+            q_approx = Signal(params.io_width)
+            # extra bit for if it's bigger than orig_d
+            r_approx = Signal(params.io_width + 1)
+            adjusted_r = Signal(signed(1 + params.io_width))
+            m.d.comb += [
+                q_approx.eq((state.n << state.n_shift)
+                            >> params.expanded_width),
+                r_approx.eq(state.orig_n - q_approx * state.orig_d),
+                adjusted_r.eq(r_approx - state.orig_d),
+            ]
+            state.quotient = Signal(params.io_width)
+            state.remainder = Signal(params.io_width)
+
+            with m.If(adjusted_r >= 0):
+                m.d.comb += [
+                    state.quotient.eq(q_approx + 1),
+                    state.remainder.eq(adjusted_r),
+                ]
+            with m.Else():
+                m.d.comb += [
+                    state.quotient.eq(q_approx),
+                    state.remainder.eq(r_approx),
+                ]
+        else:
+            assert False, f"unimplemented GoldschmidtDivOp: {self}"
+
 
 @dataclass
 class GoldschmidtDivState:
@@ -1050,6 +1196,170 @@ def goldschmidt_div(n, d, params):
     return state.quotient, state.remainder
 
 
+@dataclass(eq=False)
+class GoldschmidtDivHDLState:
+    m: Module
+    """The HDL Module"""
+
+    orig_n: Signal
+    """original numerator"""
+
+    orig_d: Signal
+    """original denominator"""
+
+    n: Signal
+    """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+    d: Signal
+    """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+    f: "Signal | None" = None
+    """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+    quotient: "Signal | None" = None
+    """final quotient"""
+
+    remainder: "Signal | None" = None
+    """final remainder"""
+
+    n_shift: "Signal | None" = None
+    """amount the numerator needs to be left-shifted at the end of the
+    algorithm.
+    """
+
+    old_signals: "defaultdict[str, list[Signal]]" = field(repr=False,
+                                                          init=False)
+
+    __signal_name_prefix: "str" = field(default="state_", repr=False,
+                                        init=False)
+
+    def __post_init__(self):
+        # called by the autogenerated __init__
+        self.old_signals = defaultdict(list)
+
+    def __setattr__(self, name, value):
+        """set a field. names starting with `_`, and every assignment made
+        before `__post_init__` has created `old_signals`, are stored as-is.
+        afterwards `value` must be a `Signal`: it is renamed to
+        `<prefix><name>_<index>` and the signal it replaces, if any, is
+        appended to `old_signals[name]`.
+        """
+        assert isinstance(name, str)
+        if name.startswith("_"):
+            return super().__setattr__(name, value)
+        try:
+            old_signals = self.old_signals[name]
+        except AttributeError:
+            # haven't yet finished __post_init__
+            return super().__setattr__(name, value)
+        assert name != "m" and name != "old_signals", f"can't write to {name}"
+        assert isinstance(value, Signal)
+        value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
+        old_signal = getattr(self, name, None)
+        if old_signal is not None:
+            assert isinstance(old_signal, Signal)
+            old_signals.append(old_signal)
+        return super().__setattr__(name, value)
+
+    def insert_pipeline_register(self):
+        """insert a pipeline register: every `Signal` field is replaced by a
+        new signal that is assigned from the old one in the `sync` domain.
+        `m` and the `old_signals` bookkeeping dict are skipped.
+        """
+        old_prefix = self.__signal_name_prefix
+        try:
+            for field in fields(GoldschmidtDivHDLState):
+                if field.name.startswith("_") \
+                        or field.name in ("m", "old_signals"):
+                    continue
+                old_sig = getattr(self, field.name, None)
+                if old_sig is None:
+                    continue
+                assert isinstance(old_sig, Signal)
+                new_sig = Signal.like(old_sig)
+                setattr(self, field.name, new_sig)
+                self.m.d.sync += new_sig.eq(old_sig)
+        finally:
+            self.__signal_name_prefix = old_prefix
+
+
+class GoldschmidtDivHDL(Elaboratable):
+    # FIXME: finish getting hdl/simulation to work
+    """ Goldschmidt division algorithm.
+
+        based on:
+        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+        attributes:
+        params: GoldschmidtDivParams
+            the goldschmidt division algorithm parameters.
+        pipe_reg_indexes: list[int]
+            the operation indexes where pipeline registers should be inserted.
+            duplicate values mean multiple registers should be inserted for
+            that operation index -- this is useful to allow yosys to spread a
+            multiplication across those multiple pipeline stages.
+        sync_rom: bool
+            true if the rom should be read synchronously rather than
+            combinatorially, incurring an extra clock cycle of latency.
+        n: Signal(unsigned(2 * params.io_width))
+            input numerator. a `2 * params.io_width`-bit unsigned integer.
+            must be less than `d << params.io_width`, otherwise the quotient
+            wouldn't fit in `params.io_width` bits.
+        d: Signal(unsigned(params.io_width))
+            input denominator. a `params.io_width`-bit unsigned integer.
+            must not be zero.
+        q: Signal(unsigned(params.io_width))
+            output quotient. only valid when `n < (d << params.io_width)`.
+        r: Signal(unsigned(params.io_width))
+            output remainder. only valid when `n < (d << params.io_width)`.
+    """
+
+    @property
+    def total_pipeline_registers(self):
+        """the total number of pipeline registers"""
+        return len(self.pipe_reg_indexes) + self.sync_rom
+
+    def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
+        assert isinstance(params, GoldschmidtDivParams)
+        assert isinstance(sync_rom, bool)
+        self.params = params
+        self.pipe_reg_indexes = sorted(int(i) for i in pipe_reg_indexes)
+        self.sync_rom = sync_rom
+        self.n = Signal(unsigned(2 * params.io_width))
+        self.d = Signal(unsigned(params.io_width))
+        self.q = Signal(unsigned(params.io_width))
+        self.r = Signal(unsigned(params.io_width))
+
+    def elaborate(self, platform):
+        m = Module()
+        state = GoldschmidtDivHDLState(
+            m=m,
+            orig_n=self.n,
+            orig_d=self.d,
+            n=self.n,
+            d=self.d)
+
+        # copy and reverse
+        pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
+
+        for op_index, op in enumerate(self.params.ops):
+            while len(pipe_reg_indexes) > 0 \
+                    and pipe_reg_indexes[-1] <= op_index:
+                pipe_reg_indexes.pop()
+                state.insert_pipeline_register()
+            op.gen_hdl(self.params, state, self.sync_rom)
+
+        while len(pipe_reg_indexes) > 0:
+            pipe_reg_indexes.pop()
+            state.insert_pipeline_register()
+
+        m.d.comb += self.q.eq(state.quotient)
+        m.d.comb += self.r.eq(state.remainder)
+        return m
+
+
 GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
 
 
diff --git a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py
index e2984dc1..5b4c89ad 100644
--- a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py
+++ b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py
@@ -7,9 +7,13 @@
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
+from nmutil.sim_util import do_sim
+from nmigen.sim import Tick, Delay
+from nmigen.hdl.ast import Signal
+from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div,
-    FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
+    GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough,
+    goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
 
 class TestFixedPoint(FHDLTestCase):
@@ -84,6 +88,57 @@ class TestGoldschmidtDiv(FHDLTestCase):
                         with self.subTest(q=hex(q), r=hex(r)):
                             self.assertEqual((q, r), (expected_q, expected_r))
 
+    @unittest.skip("hdl/simulation currently broken")
+    def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
+                sync_rom=False):
+        # FIXME: finish getting hdl/simulation to work
+        assert isinstance(io_width, int)
+        params = GoldschmidtDivParams.get(io_width)
+        m = Module()
+        dut = GoldschmidtDivHDL(params, pipe_reg_indexes=pipe_reg_indexes,
+                                sync_rom=sync_rom)
+        m.submodules.dut = dut
+        # make sync domain get added
+        m.d.sync += Signal().eq(0)
+
+        def iter_cases():
+            if cases is not None:
+                yield from cases
+                return
+            for d in range(1, 1 << io_width):
+                for n in range(d << io_width):
+                    yield (n, d)
+
+        def inputs_proc():
+            yield Tick()
+            for n, d in iter_cases():
+                yield dut.n.eq(n)
+                yield dut.d.eq(d)
+                yield Tick()
+
+        def check_outputs():
+            yield Tick()
+            for _ in range(dut.total_pipeline_registers):
+                yield Tick()
+            for n, d in iter_cases():
+                yield Delay(0.1e-6)
+                expected_q, expected_r = divmod(n, d)
+                with self.subTest(n=hex(n), d=hex(d),
+                                  expected_q=hex(expected_q),
+                                  expected_r=hex(expected_r)):
+                    q = yield dut.q
+                    r = yield dut.r
+                    with self.subTest(q=hex(q), r=hex(r)):
+                        self.assertEqual((q, r), (expected_q, expected_r))
+                yield Tick()
+
+        with self.subTest(params=str(params)):
+            with do_sim(self, m, (dut.n, dut.d, dut.q, dut.r)) as sim:
+                sim.add_clock(1e-6)
+                sim.add_process(inputs_proc)
+                sim.add_process(check_outputs)
+                sim.run()
+
     def test_1_through_4(self):
         for io_width in range(1, 4 + 1):
             with self.subTest(io_width=io_width):
@@ -95,6 +150,9 @@ class TestGoldschmidtDiv(FHDLTestCase):
     def test_6(self):
         self.tst(6)
 
+    def test_sim_5(self):
+        self.tst_sim(5)
+
     def tst_params(self, io_width):
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)
-- 
2.30.2