from fractions import Fraction
from types import FunctionType
from functools import lru_cache
-from nmigen.hdl.ast import Signal, unsigned, Mux, signed
+from nmigen.hdl.ast import Signal, unsigned, signed, Const, Cat
from nmigen.hdl.dsl import Module, Elaboratable
from nmigen.hdl.mem import Memory
from nmutil.clz import CLZ
else:
break
- return cached_new(params)
+ retval = cached_new(params)
+ assert isinstance(retval, GoldschmidtDivParams)
+ return retval
def clz(v, wid):
d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
state.f = params.table[d_m_1.bits]
+ state.f = state.f.to_frac_wid(expanded_width,
+ round_dir=RoundDir.DOWN)
elif self == GoldschmidtDivOp.MulNByF:
assert state.f is not None
n = state.n * state.f
assert False, f"unimplemented GoldschmidtDivOp: {self}"
def gen_hdl(self, params, state, sync_rom):
- # FIXME: finish getting hdl/simulation to work
"""generate the hdl for this operation.
arguments:
assert isinstance(params, GoldschmidtDivParams)
assert isinstance(state, GoldschmidtDivHDLState)
m = state.m
- expanded_width = params.expanded_width
- table_addr_bits = params.table_addr_bits
if self == GoldschmidtDivOp.Normalize:
# normalize so 1 <= d < 2
assert state.d.width == params.io_width
m.d.comb += d_leading_zeros.sig_in.eq(state.d)
d_shift_out = Signal.like(state.d)
m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
- state.d = Signal(params.n_d_f_total_wid)
- m.d.comb += state.d.eq(d_shift_out << (params.extra_precision
- + params.n_d_f_int_wid))
+ d = Signal(params.n_d_f_total_wid)
+ m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width))
+ >> state.d.width)
# normalize so 1 <= n < 2
n_leading_zeros = CLZ(2 * params.io_width)
m.submodules.n_leading_zeros = n_leading_zeros
m.d.comb += n_leading_zeros.sig_in.eq(state.n)
- n_shift_s_v = (params.io_width + d_leading_zeros.lz
+ signed_zero = Const(0, signed(1)) # force subtraction to be signed
+ n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz
- n_leading_zeros.lz)
n_shift_s = Signal.like(n_shift_s_v)
- state.n_shift = Signal(d_leading_zeros.lz.width)
+ n_shift_n_lz_out = Signal.like(state.n)
+ n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz)
m.d.comb += [
n_shift_s.eq(n_shift_s_v),
- state.n_shift.eq(Mux(n_shift_s < 0, 0, n_shift_s)),
+ n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz),
+ n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz),
]
+ state.n_shift = Signal(d_leading_zeros.lz.width)
n = Signal(params.n_d_f_total_wid)
- shifted_n = state.n << state.n_shift
- fixed_shift = params.expanded_width - state.n.width
- m.d.comb += n.eq(shifted_n << fixed_shift)
+ with m.If(n_shift_s < 0):
+ m.d.comb += [
+ state.n_shift.eq(0),
+ n.eq((n_shift_d_lz_out << (1 + params.expanded_width))
+ >> state.d.width),
+ ]
+ with m.Else():
+ m.d.comb += [
+ state.n_shift.eq(n_shift_s),
+ n.eq((n_shift_n_lz_out << (1 + params.expanded_width))
+ >> state.n.width),
+ ]
state.n = n
+ state.d = d
elif self == GoldschmidtDivOp.FEqTableLookup:
assert state.d.width == params.n_d_f_total_wid, "invalid d width"
# compute initial f by table lookup
table_width = 1 + params.table_data_bits
table = Memory(width=table_width, depth=len(params.table),
init=[i.bits for i in params.table])
- addr = state.d[:-params.n_d_f_int_wid][-table_addr_bits:]
+ addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:]
if sync_rom:
table_read = table.read_port()
m.d.comb += table_read.addr.eq(addr)
m.d.comb += table_read.addr.eq(addr)
m.submodules.table_read = table_read
state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
- data_shift = (table_width - params.table_data_bits
- + params.expanded_width)
+ data_shift = params.expanded_width - params.table_data_bits
m.d.comb += state.f.eq(table_read.data << data_shift)
elif self == GoldschmidtDivOp.MulNByF:
assert state.n.width == params.n_d_f_total_wid, "invalid n width"
algorithm.
"""
+ def __repr__(self):
+ fields_str = []
+ for field in fields(GoldschmidtDivState):
+ value = getattr(self, field.name)
+ if value is None:
+ continue
+ if isinstance(value, int) and field.name != "n_shift":
+ fields_str.append(f"{field.name}={hex(value)}")
+ else:
+ fields_str.append(f"{field.name}={value!r}")
+ return f"GoldschmidtDivState({', '.join(fields_str)})"
+
-def goldschmidt_div(n, d, params):
+def goldschmidt_div(n, d, params, trace=lambda state: None):
""" Goldschmidt division algorithm.
based on:
denominator. a `width`-bit unsigned integer. must not be zero.
width: int
the bit-width of the inputs/outputs. must be a positive integer.
+ trace: Callable[[GoldschmidtDivState], None]
+ called with the initial state and the state after executing each
+ operation in `params.ops`.
returns: tuple[int, int]
the quotient and remainder. a tuple of two `width`-bit unsigned
d=FixedPoint(d, params.io_width),
)
+ trace(state)
for op in params.ops:
op.run(params, state)
+ trace(state)
assert state.quotient is not None
assert state.remainder is not None
class GoldschmidtDivHDL(Elaboratable):
- # FIXME: finish getting hdl/simulation to work
""" Goldschmidt division algorithm.
based on:
output quotient. only valid when `n < (d << params.io_width)`.
r: Signal(unsigned(params.io_width))
output remainder. only valid when `n < (d << params.io_width)`.
+ trace: list[GoldschmidtDivHDLState]
+ list of the initial state and the state after executing each
+ operation in `params.ops`.
"""
@property
self.q = Signal(unsigned(params.io_width))
self.r = Signal(unsigned(params.io_width))
- def elaborate(self, platform):
- m = Module()
+ # build the HDL here in the constructor so `trace` is available without needing to call elaborate
state = GoldschmidtDivHDLState(
- m=m,
+ m=Module(),
orig_n=self.n,
orig_d=self.d,
n=self.n,
d=self.d)
+ self.trace = [replace(state)]
+
# copy and reverse
pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
pipe_reg_indexes.pop()
state.insert_pipeline_register()
op.gen_hdl(self.params, state, self.sync_rom)
+ self.trace.append(replace(state))
while len(pipe_reg_indexes) > 0:
pipe_reg_indexes.pop()
state.insert_pipeline_register()
- m.d.comb += self.q.eq(state.quotient)
- m.d.comb += self.r.eq(state.remainder)
- return m
+ state.m.d.comb += [
+ self.q.eq(state.quotient),
+ self.r.eq(state.remainder),
+ ]
+
+ def elaborate(self, platform):
+ return self.trace[0].m
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
+from dataclasses import fields, replace
import math
import unittest
from nmutil.formaltest import FHDLTestCase
from nmigen.hdl.ast import Signal
from nmigen.hdl.dsl import Module
from soc.fu.div.experiment.goldschmidt_div_sqrt import (
- GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough,
- goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
+ GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
+ GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
+ FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
class TestFixedPoint(FHDLTestCase):
with self.subTest(q=hex(q), r=hex(r)):
self.assertEqual((q, r), (expected_q, expected_r))
- @unittest.skip("hdl/simulation currently broken")
def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
sync_rom=False):
- # FIXME: finish getting hdl/simulation to work
assert isinstance(io_width, int)
params = GoldschmidtDivParams.get(io_width)
m = Module()
def iter_cases():
if cases is not None:
- yield from cases
+ for n, d in cases:
+ assert isinstance(d, int) \
+ and 0 < d < (1 << params.io_width), "invalid case"
+ assert isinstance(n, int) \
+ and 0 <= n < (d << params.io_width), "invalid case"
+ yield (n, d)
return
for d in range(1, 1 << io_width):
for n in range(d << io_width):
yield dut.d.eq(d)
yield Tick()
+ def check_interals(n, d):
+ # check internals only if dut is completely combinatorial
+ # so we don't have to figure out how to read values in
+ # previous clock cycles
+ if dut.total_pipeline_registers != 0:
+ return
+ ref_trace = []
+
+ def ref_trace_fn(state):
+ assert isinstance(state, GoldschmidtDivState)
+ ref_trace.append((replace(state)))
+ goldschmidt_div(n=n, d=d, params=params, trace=ref_trace_fn)
+ self.assertEqual(len(dut.trace), len(ref_trace))
+ for index, state in enumerate(dut.trace):
+ ref_state = ref_trace[index]
+ last_op = None if index == 0 else params.ops[index - 1]
+ with self.subTest(index=index, state=repr(state),
+ ref_state=repr(ref_state),
+ last_op=str(last_op)):
+ for field in fields(GoldschmidtDivHDLState):
+ sig = getattr(state, field.name)
+ if not isinstance(sig, Signal):
+ continue
+ ref_value = getattr(ref_state, field.name)
+ ref_value_str = repr(ref_value)
+ if isinstance(ref_value, int):
+ ref_value_str = hex(ref_value)
+ value = yield sig
+ with self.subTest(field_name=field.name,
+ sig=repr(sig),
+ sig_shape=repr(sig.shape()),
+ value=hex(value),
+ ref_value=ref_value_str):
+ if isinstance(ref_value, int):
+ self.assertEqual(value, ref_value)
+ else:
+ assert isinstance(ref_value, FixedPoint)
+ self.assertEqual(value, ref_value.bits)
+
def check_outputs():
yield Tick()
for _ in range(dut.total_pipeline_registers):
r = yield dut.r
with self.subTest(q=hex(q), r=hex(r)):
self.assertEqual((q, r), (expected_q, expected_r))
+ yield from check_interals(n, d)
+
yield Tick()
with self.subTest(params=str(params)):