From 2d9fb70cf873f26b77a90cd938e9d656afc4e1d1 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 28 Apr 2022 22:40:32 -0700 Subject: [PATCH] HDL works for io_width=5 --- .../fu/div/experiment/goldschmidt_div_sqrt.py | 89 +++++++++++++------ .../test/test_goldschmidt_div_sqrt.py | 56 ++++++++++-- 2 files changed, 115 insertions(+), 30 deletions(-) diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py index ea0ddda0..a86fa78d 100644 --- a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py @@ -12,7 +12,7 @@ import enum from fractions import Fraction from types import FunctionType from functools import lru_cache -from nmigen.hdl.ast import Signal, unsigned, Mux, signed +from nmigen.hdl.ast import Signal, unsigned, signed, Const, Cat from nmigen.hdl.dsl import Module, Elaboratable from nmigen.hdl.mem import Memory from nmutil.clz import CLZ @@ -937,7 +937,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase): else: break - return cached_new(params) + retval = cached_new(params) + assert isinstance(retval, GoldschmidtDivParams) + return retval def clz(v, wid): @@ -979,6 +981,8 @@ class GoldschmidtDivOp(enum.Enum): d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN) assert 0 <= d_m_1.bits < (1 << params.table_addr_bits) state.f = params.table[d_m_1.bits] + state.f = state.f.to_frac_wid(expanded_width, + round_dir=RoundDir.DOWN) elif self == GoldschmidtDivOp.MulNByF: assert state.f is not None n = state.n * state.f @@ -1003,7 +1007,6 @@ class GoldschmidtDivOp(enum.Enum): assert False, f"unimplemented GoldschmidtDivOp: {self}" def gen_hdl(self, params, state, sync_rom): - # FIXME: finish getting hdl/simulation to work """generate the hdl for this operation. arguments: @@ -1018,8 +1021,6 @@ class GoldschmidtDivOp(enum.Enum): assert isinstance(params, GoldschmidtDivParams) assert isinstance(state, GoldschmidtDivHDLState) m = state.m - expanded_width = params.expanded_width - table_addr_bits = params.table_addr_bits if self == GoldschmidtDivOp.Normalize: # normalize so 1 <= d < 2 assert state.d.width == params.io_width @@ -1029,27 +1030,41 @@ class GoldschmidtDivOp(enum.Enum): m.d.comb += d_leading_zeros.sig_in.eq(state.d) d_shift_out = Signal.like(state.d) m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz) - state.d = Signal(params.n_d_f_total_wid) - m.d.comb += state.d.eq(d_shift_out << (params.extra_precision - + params.n_d_f_int_wid)) + d = Signal(params.n_d_f_total_wid) + m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width)) + >> state.d.width) # normalize so 1 <= n < 2 n_leading_zeros = CLZ(2 * params.io_width) m.submodules.n_leading_zeros = n_leading_zeros m.d.comb += n_leading_zeros.sig_in.eq(state.n) - n_shift_s_v = (params.io_width + d_leading_zeros.lz + signed_zero = Const(0, signed(1)) # force subtraction to be signed + n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz - n_leading_zeros.lz) n_shift_s = Signal.like(n_shift_s_v) - state.n_shift = Signal(d_leading_zeros.lz.width) + n_shift_n_lz_out = Signal.like(state.n) + n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz) m.d.comb += [ n_shift_s.eq(n_shift_s_v), - state.n_shift.eq(Mux(n_shift_s < 0, 0, n_shift_s)), + n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz), + n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz), ] + state.n_shift = Signal(d_leading_zeros.lz.width) n = Signal(params.n_d_f_total_wid) - shifted_n = state.n << state.n_shift - fixed_shift = params.expanded_width - state.n.width - m.d.comb += n.eq(shifted_n << fixed_shift) + with m.If(n_shift_s < 0): + m.d.comb += [ + state.n_shift.eq(0), + n.eq((n_shift_d_lz_out << (1 + params.expanded_width)) + >> state.d.width), + ] + with m.Else(): + m.d.comb += [ + state.n_shift.eq(n_shift_s), + n.eq((n_shift_n_lz_out << (1 + params.expanded_width)) + >> state.n.width), + ] state.n = n + state.d = d elif self == GoldschmidtDivOp.FEqTableLookup: assert state.d.width == params.n_d_f_total_wid, "invalid d width" # compute initial f by table lookup @@ -1058,7 +1073,7 @@ class GoldschmidtDivOp(enum.Enum): table_width = 1 + params.table_data_bits table = Memory(width=table_width, depth=len(params.table), init=[i.bits for i in params.table]) - addr = state.d[:-params.n_d_f_int_wid][-table_addr_bits:] + addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:] if sync_rom: table_read = table.read_port() m.d.comb += table_read.addr.eq(addr) @@ -1068,8 +1083,7 @@ class GoldschmidtDivOp(enum.Enum): m.d.comb += table_read.addr.eq(addr) m.submodules.table_read = table_read state.f = Signal(params.n_d_f_int_wid + params.expanded_width) - data_shift = (table_width - params.table_data_bits - + params.expanded_width) + data_shift = params.expanded_width - params.table_data_bits m.d.comb += state.f.eq(table_read.data << data_shift) elif self == GoldschmidtDivOp.MulNByF: assert state.n.width == params.n_d_f_total_wid, "invalid n width" @@ -1150,8 +1164,20 @@ class GoldschmidtDivState: algorithm. """ + def __repr__(self): + fields_str = [] + for field in fields(GoldschmidtDivState): + value = getattr(self, field.name) + if value is None: + continue + if isinstance(value, int) and field.name != "n_shift": + fields_str.append(f"{field.name}={hex(value)}") + else: + fields_str.append(f"{field.name}={value!r}") + return f"GoldschmidtDivState({', '.join(fields_str)})" + -def goldschmidt_div(n, d, params): +def goldschmidt_div(n, d, params, trace=lambda state: None): """ Goldschmidt division algorithm. based on: @@ -1168,6 +1194,9 @@ def goldschmidt_div(n, d, params): denominator. a `width`-bit unsigned integer. must not be zero. width: int the bit-width of the inputs/outputs. must be a positive integer. + trace: Function[[GoldschmidtDivState], None] + called with the initial state and the state after executing each + operation in `params.ops`. returns: tuple[int, int] the quotient and remainder. a tuple of two `width`-bit unsigned @@ -1187,8 +1216,10 @@ def goldschmidt_div(n, d, params): d=FixedPoint(d, params.io_width), ) + trace(state) for op in params.ops: op.run(params, state) + trace(state) assert state.quotient is not None assert state.remainder is not None @@ -1273,7 +1304,6 @@ class GoldschmidtDivHDLState: class GoldschmidtDivHDL(Elaboratable): - # FIXME: finish getting hdl/simulation to work """ Goldschmidt division algorithm. based on: @@ -1303,6 +1333,9 @@ class GoldschmidtDivHDL(Elaboratable): output quotient. only valid when `n < (d << params.io_width)`. r: Signal(unsigned(params.io_width)) output remainder. only valid when `n < (d << params.io_width)`. + trace: list[GoldschmidtDivHDLState] + list of the initial state and the state after executing each + operation in `params.ops`. """ @property @@ -1321,15 +1354,16 @@ class GoldschmidtDivHDL(Elaboratable): self.q = Signal(unsigned(params.io_width)) self.r = Signal(unsigned(params.io_width)) - def elaborate(self, platform): - m = Module() + # in constructor so we get trace without needing to call elaborate state = GoldschmidtDivHDLState( - m=m, + m=Module(), orig_n=self.n, orig_d=self.d, n=self.n, d=self.d) + self.trace = [replace(state)] + # copy and reverse pipe_reg_indexes = list(reversed(self.pipe_reg_indexes)) @@ -1339,14 +1373,19 @@ class GoldschmidtDivHDL(Elaboratable): pipe_reg_indexes.pop() state.insert_pipeline_register() op.gen_hdl(self.params, state, self.sync_rom) + self.trace.append(replace(state)) while len(pipe_reg_indexes) > 0: pipe_reg_indexes.pop() state.insert_pipeline_register() - m.d.comb += self.q.eq(state.quotient) - m.d.comb += self.r.eq(state.remainder) - return m + state.m.d.comb += [ + self.q.eq(state.quotient), + self.r.eq(state.remainder), + ] + + def elaborate(self, platform): + return self.trace[0].m GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2 diff --git a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py index 5b4c89ad..b4d4fb85 100644 --- a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py @@ -4,6 +4,7 @@ # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part # of Horizon 2020 EU Programme 957073. +from dataclasses import fields, replace import math import unittest from nmutil.formaltest import FHDLTestCase @@ -12,8 +13,9 @@ from nmigen.sim import Tick, Delay from nmigen.hdl.ast import Signal from nmigen.hdl.dsl import Module from soc.fu.div.experiment.goldschmidt_div_sqrt import ( - GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough, - goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt) + GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams, + GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div, + FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt) class TestFixedPoint(FHDLTestCase): @@ -88,10 +90,8 @@ class TestGoldschmidtDiv(FHDLTestCase): with self.subTest(q=hex(q), r=hex(r)): self.assertEqual((q, r), (expected_q, expected_r)) - @unittest.skip("hdl/simulation currently broken") def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(), sync_rom=False): - # FIXME: finish getting hdl/simulation to work assert isinstance(io_width, int) params = GoldschmidtDivParams.get(io_width) m = Module() @@ -103,7 +103,12 @@ class TestGoldschmidtDiv(FHDLTestCase): def iter_cases(): if cases is not None: - yield from cases + for n, d in cases: + assert isinstance(d, int) \ + and 0 < d < (1 << params.io_width), "invalid case" + assert isinstance(n, int) \ + and 0 <= n < (d << params.io_width), "invalid case" + yield (n, d) return for d in range(1, 1 << io_width): for n in range(d << io_width): @@ -116,6 +121,45 @@ class TestGoldschmidtDiv(FHDLTestCase): yield dut.d.eq(d) yield Tick() + def check_interals(n, d): + # check internals only if dut is completely combinatorial + # so we don't have to figure out how to read values in + # previous clock cycles + if dut.total_pipeline_registers != 0: + return + ref_trace = [] + + def ref_trace_fn(state): + assert isinstance(state, GoldschmidtDivState) + ref_trace.append((replace(state))) + goldschmidt_div(n=n, d=d, params=params, trace=ref_trace_fn) + self.assertEqual(len(dut.trace), len(ref_trace)) + for index, state in enumerate(dut.trace): + ref_state = ref_trace[index] + last_op = None if index == 0 else params.ops[index - 1] + with self.subTest(index=index, state=repr(state), + ref_state=repr(ref_state), + last_op=str(last_op)): + for field in fields(GoldschmidtDivHDLState): + sig = getattr(state, field.name) + if not isinstance(sig, Signal): + continue + ref_value = getattr(ref_state, field.name) + ref_value_str = repr(ref_value) + if isinstance(ref_value, int): + ref_value_str = hex(ref_value) + value = yield sig + with self.subTest(field_name=field.name, + sig=repr(sig), + sig_shape=repr(sig.shape()), + value=hex(value), + ref_value=ref_value_str): + if isinstance(ref_value, int): + self.assertEqual(value, ref_value) + else: + assert isinstance(ref_value, FixedPoint) + self.assertEqual(value, ref_value.bits) + def check_outputs(): yield Tick() for _ in range(dut.total_pipeline_registers): @@ -130,6 +174,8 @@ class TestGoldschmidtDiv(FHDLTestCase): r = yield dut.r with self.subTest(q=hex(q), r=hex(r)): self.assertEqual((q, r), (expected_q, expected_r)) + yield from check_interals(n, d) + yield Tick() with self.subTest(params=str(params)): -- 2.30.2