https://bugs.libre-soc.org/show_bug.cgi?id=784
"""
+from dataclasses import dataclass, field, fields
from nmigen.hdl.ir import Elaboratable
-from nmigen.hdl.ast import Signal
+from nmigen.hdl.ast import Signal, Value
from nmigen.hdl.dsl import Module
+from nmutil.singlepipe import ControlBase
def equal_leading_zero_count_reference(a, b, width):
return m
-# TODO: add CLDivRem
+
+@dataclass(frozen=True, unsafe_hash=True)
+class CLDivRemShape:
+ width: int
+ n_width: int
+
+ def __post_init__(self):
+ assert self.n_width >= self.width > 0
+
+ @property
+ def done_step(self):
+ return self.width
+
+ @property
+ def step_range(self):
+ return range(self.done_step + 1)
+
+
+@dataclass(frozen=True, eq=False)
+class CLDivRemState:
+ shape: CLDivRemShape
+ name: str
+ d: Signal = field(init=False)
+ r: Signal = field(init=False)
+ q: Signal = field(init=False)
+ step: Signal = field(init=False)
+
+ def __init__(self, shape, *, name=None, src_loc_at=0):
+ assert isinstance(shape, CLDivRemShape)
+ if name is None:
+ name = Signal(src_loc_at=1 + src_loc_at).name
+ assert isinstance(name, str)
+ d = Signal(2 * shape.width, name=f"{name}_d")
+ r = Signal(shape.n_width, name=f"{name}_r")
+ q = Signal(shape.width, name=f"{name}_q")
+ step = Signal(shape.width, name=f"{name}_step")
+ object.__setattr__(self, "shape", shape)
+ object.__setattr__(self, "name", name)
+ object.__setattr__(self, "d", d)
+ object.__setattr__(self, "r", r)
+ object.__setattr__(self, "q", q)
+ object.__setattr__(self, "step", step)
+
+ def eq(self, rhs):
+ assert isinstance(rhs, CLDivRemState)
+ for f in fields(CLDivRemState):
+ if f.name in ("shape", "name"):
+ continue
+ l = getattr(self, f.name)
+ r = getattr(rhs, f.name)
+ yield l.eq(r)
+
+ @staticmethod
+ def like(other, *, name=None, src_loc_at=0):
+ assert isinstance(other, CLDivRemState)
+ return CLDivRemState(other.shape, name=name, src_loc_at=1 + src_loc_at)
+
+ @property
+ def done(self):
+ return self.will_be_done_after(steps=0)
+
+ def will_be_done_after(self, steps):
+ """ Returns True if this state will be done after
+ another `steps` passes through `set_to_next`."""
+ assert isinstance(steps, int) and steps >= 0
+ return self.step >= max(0, self.shape.done_step - steps)
+
+ def set_to_initial(self, m, n, d):
+ assert isinstance(m, Module)
+ m.d.comb += [
+ self.d.eq(Value.cast(d) << self.shape.width),
+ self.r.eq(n),
+ self.q.eq(0),
+ self.step.eq(0),
+ ]
+
+ def set_to_next(self, m, state_in):
+ assert isinstance(m, Module)
+ assert isinstance(state_in, CLDivRemState)
+ assert state_in.shape == self.shape
+ assert self is not state_in, "a.set_to_next(m, a) is not allowed"
+
+ equal_leading_zero_count = EqualLeadingZeroCount(self.shape.n_width)
+ # can't name submodule since it would conflict if this function is
+ # called multiple times in a Module
+ m.submodules += equal_leading_zero_count
+
+ with m.If(state_in.done):
+ m.d.comb += self.eq(state_in)
+ with m.Else():
+ m.d.comb += [
+ self.step.eq(state_in.step + 1),
+ self.d.eq(state_in.d >> 1),
+ equal_leading_zero_count.a.eq(self.d),
+ equal_leading_zero_count.b.eq(state_in.r),
+ ]
+ d_top = self.d[self.shape.n_width:]
+ with m.If(equal_leading_zero_count.out & (d_top == 0)):
+ m.d.comb += [
+ self.r.eq(state_in.r ^ self.d),
+ self.q.eq((state_in.q << 1) | 1),
+ ]
+ with m.Else():
+ m.d.comb += [
+ self.r.eq(state_in.r),
+ self.q.eq(state_in.q << 1),
+ ]
+
+
+class CLDivRemInputData:
+ def __init__(self, shape):
+ assert isinstance(shape, CLDivRemShape)
+ self.shape = shape
+ self.n = Signal(shape.n_width)
+ self.d = Signal(shape.width)
+
+ def __iter__(self):
+ """ Get member signals. """
+ yield self.n
+ yield self.d
+
+ def eq(self, rhs):
+ """ Assign member signals. """
+ return [
+ self.n.eq(rhs.n),
+ self.d.eq(rhs.d),
+ ]
+
+
+class CLDivRemOutputData:
+ def __init__(self, shape):
+ assert isinstance(shape, CLDivRemShape)
+ self.shape = shape
+ self.q = Signal(shape.width)
+ self.r = Signal(shape.width)
+
+ def __iter__(self):
+ """ Get member signals. """
+ yield self.q
+ yield self.r
+
+ def eq(self, rhs):
+ """ Assign member signals. """
+ return [
+ self.q.eq(rhs.q),
+ self.r.eq(rhs.r),
+ ]
+
+
+class CLDivRemFSMStage(ControlBase):
+ """carry-less div/rem
+
+ Attributes:
+ shape: CLDivRemShape
+ the shape
+ steps_per_clock: int
+ number of steps that should be taken per clock cycle
+ in_valid: Signal()
+ input. true when the data inputs (`n` and `d`) are valid.
+ data transfer in occurs when `in_valid & in_ready`.
+ in_ready: Signal()
+ output. true when this FSM is ready to accept input.
+ data transfer in occurs when `in_valid & in_ready`.
+ n: Signal(shape.n_width)
+ numerator in, the value must be small enough that `q` and `r` don't
+ overflow. having `n_width == width` is sufficient.
+ d: Signal(shape.width)
+ denominator in, must be non-zero.
+ q: Signal(shape.width)
+ quotient out.
+ r: Signal(shape.width)
+ remainder out.
+ out_valid: Signal()
+ output. true when the data outputs (`q` and `r`) are valid
+ (or are junk because the inputs were out of range).
+ data transfer out occurs when `out_valid & out_ready`.
+ out_ready: Signal()
+ input. true when the output can be read.
+ data transfer out occurs when `out_valid & out_ready`.
+ """
+
+ def __init__(self, pspec, shape, *, steps_per_clock=4):
+ assert isinstance(shape, CLDivRemShape)
+ assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
+ self.shape = shape
+ self.steps_per_clock = steps_per_clock
+ self.pspec = pspec # store now: used in ispec and ospec
+ super().__init__(stage=self)
+ self.empty = Signal(reset=1)
+ self.saved_state = CLDivRemState(shape)
+
+ def ispec(self):
+ return CLDivRemInputData(self.shape)
+
+ def ospec(self):
+ return CLDivRemOutputData(self.shape)
+
+ def setup(self, m, i):
+ pass
+
+ def elaborate(self, platform):
+ m = super().elaborate(platform)
+ i_data: CLDivRemInputData = self.p.i_data
+ o_data: CLDivRemOutputData = self.n.o_data
+
+ # TODO: handle cancellation
+
+ state_will_be_done = self.saved_state.will_be_done_after(
+ self.steps_per_clock)
+ m.d.comb += self.n.o_valid.eq(~self.empty & state_will_be_done)
+ m.d.comb += self.p.o_ready.eq(self.empty)
+
+ def make_nc(i):
+ return CLDivRemState(self.shape, name=f"next_chain_{i}")
+ next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)]
+ for i in range(self.steps_per_clock):
+ next_chain[i + 1].set_to_next(m, next_chain[i])
+ m.d.sync += self.saved_state.eq(next_chain[-1])
+ m.d.comb += o_data.q.eq(next_chain[-1].q)
+ m.d.comb += o_data.r.eq(next_chain[-1].r)
+
+ with m.If(self.empty):
+ next_chain[0].set_to_initial(m, n=i_data.n, d=i_data.d)
+ with m.If(self.p.i_valid):
+ m.d.sync += self.empty.eq(0)
+ with m.Else():
+ m.d.comb += next_chain[0].eq(self.saved_state)
+ with m.If(self.n.i_ready & self.n.o_valid):
+ m.d.sync += self.empty.eq(1)
+
+ return m
+
+ def __iter__(self):
+ yield from self.p
+ yield from self.n
+
+ def ports(self):
+ return list(self)
from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, unsigned
from nmigen.hdl.dsl import Module
from nmutil.formaltest import FHDLTestCase
-from nmigen_gf.hdl.cldivrem import (equal_leading_zero_count_reference,
+from nmigen_gf.hdl.cldivrem import (CLDivRemFSMStage, CLDivRemInputData,
+ CLDivRemOutputData, CLDivRemShape, CLDivRemState,
+ equal_leading_zero_count_reference,
EqualLeadingZeroCount)
-from nmigen.sim import Delay
+from nmigen.sim import Delay, Tick
from nmutil.sim_util import do_sim, hash_256
+from nmigen_gf.reference.cldivrem import cldivrem
class TestEqualLeadingZeroCount(FHDLTestCase):
def test_formal_3(self):
self.tst_formal(3)
-# TODO: add TestCLDivRem
+
+class TestCLDivRemComb(FHDLTestCase):
+ def tst(self, shape, full):
+ assert isinstance(shape, CLDivRemShape)
+ m = Module()
+ n_in = Signal(shape.n_width)
+ d_in = Signal(shape.width)
+ states: "list[CLDivRemState]" = []
+ for i in shape.step_range:
+ states.append(CLDivRemState(shape, name=f"state_{i}"))
+ if i == 0:
+ states[i].set_to_initial(m, n=n_in, d=d_in)
+ else:
+ states[i].set_to_next(m, states[i - 1])
+
+ def case(n, d):
+ assert isinstance(n, int)
+ assert isinstance(d, int)
+ max_width = max(shape.width, shape.n_width)
+ if d != 0:
+ expected_q, expected_r = cldivrem(n, d, width=max_width)
+ else:
+ expected_q = expected_r = 0
+ with self.subTest(n=hex(n), d=hex(d),
+ expected_q=hex(expected_q),
+ expected_r=hex(expected_r)):
+ yield n_in.eq(n)
+ yield d_in.eq(d)
+ yield Delay(1e-6)
+ for i in shape.step_range:
+ with self.subTest(i=i):
+ done = yield states[i].done
+ step = yield states[i].step
+ self.assertEqual(done, i >= shape.done_step)
+ self.assertEqual(step, i)
+ q = yield states[-1].q
+ r = yield states[-1].r
+ with self.subTest(q=hex(q), r=hex(r)):
+ # only check results when inputs are valid
+ if d != 0 and (expected_q >> shape.width) == 0:
+ self.assertEqual(q, expected_q)
+ self.assertEqual(r, expected_r)
+
+ def process():
+ if full:
+ for n in range(1 << shape.n_width):
+ for d in range(1 << shape.width):
+ yield from case(n, d)
+ else:
+ for i in range(100):
+ n = hash_256(f"cldivrem comb n {i}")
+ n = Const.normalize(n, unsigned(shape.n_width))
+ d = hash_256(f"cldivrem comb d {i}")
+ d = Const.normalize(d, unsigned(shape.width))
+ yield from case(n, d)
+ with do_sim(self, m, [n_in, d_in, states[-1].q, states[-1].r]) as sim:
+ sim.add_process(process)
+ sim.run()
+
+ def test_4(self):
+ self.tst(CLDivRemShape(width=4, n_width=4), full=True)
+
+ def test_8_by_4(self):
+ self.tst(CLDivRemShape(width=4, n_width=8), full=True)
+
+
+class TestCLDivRemFSM(FHDLTestCase):
+ def tst(self, shape, full, steps_per_clock):
+ assert isinstance(shape, CLDivRemShape)
+ assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
+ pspec = {}
+ dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
+ i_data: CLDivRemInputData = dut.p.i_data
+ o_data: CLDivRemOutputData = dut.n.o_data
+ self.assertEqual(i_data.n.shape(), unsigned(shape.n_width))
+ self.assertEqual(i_data.d.shape(), unsigned(shape.width))
+ self.assertEqual(o_data.q.shape(), unsigned(shape.width))
+ self.assertEqual(o_data.r.shape(), unsigned(shape.width))
+
+ def case(n, d):
+ assert isinstance(n, int)
+ assert isinstance(d, int)
+ max_width = max(shape.width, shape.n_width)
+ if d != 0:
+ expected_q, expected_r = cldivrem(n, d, width=max_width)
+ else:
+ expected_q = expected_r = 0
+ with self.subTest(n=hex(n), d=hex(d),
+ expected_q=hex(expected_q),
+ expected_r=hex(expected_r)):
+ yield dut.p.i_valid.eq(0)
+ yield Tick()
+ yield i_data.n.eq(n)
+ yield i_data.d.eq(d)
+ yield dut.p.i_valid.eq(1)
+ yield Delay(0.1e-6)
+ valid = yield dut.n.o_valid
+ ready = yield dut.p.o_ready
+ with self.subTest():
+ self.assertFalse(valid)
+ self.assertTrue(ready)
+ yield Tick()
+ yield i_data.n.eq(-1)
+ yield i_data.d.eq(-1)
+ yield dut.p.i_valid.eq(0)
+ for i in range(steps_per_clock * 2, shape.done_step,
+ steps_per_clock):
+ yield Delay(0.1e-6)
+ valid = yield dut.n.o_valid
+ ready = yield dut.p.o_ready
+ with self.subTest():
+ self.assertFalse(valid)
+ self.assertFalse(ready)
+ yield Tick()
+ yield Delay(0.1e-6)
+ valid = yield dut.n.o_valid
+ ready = yield dut.p.o_ready
+ with self.subTest():
+ self.assertTrue(valid)
+ self.assertFalse(ready)
+ q = yield o_data.q
+ r = yield o_data.r
+ with self.subTest(q=hex(q), r=hex(r)):
+ # only check results when inputs are valid
+ if d != 0 and (expected_q >> shape.width) == 0:
+ self.assertEqual(q, expected_q)
+ self.assertEqual(r, expected_r)
+ yield dut.n.i_ready.eq(1)
+ yield Tick()
+ yield Delay(0.1e-6)
+ valid = yield dut.n.o_valid
+ ready = yield dut.p.o_ready
+ with self.subTest():
+ self.assertFalse(valid)
+ self.assertTrue(ready)
+ yield dut.n.i_ready.eq(0)
+
+ def process():
+ if full:
+ for n in range(1 << shape.n_width):
+ for d in range(1 << shape.width):
+ yield from case(n, d)
+ else:
+ for i in range(100):
+ n = hash_256(f"cldivrem fsm n {i}")
+ n = Const.normalize(n, unsigned(shape.n_width))
+ d = hash_256(f"cldivrem fsm d {i}")
+ d = Const.normalize(d, unsigned(shape.width))
+ yield from case(n, d)
+
+ with do_sim(self, dut, list(dut.ports())) as sim:
+ sim.add_process(process)
+ sim.add_clock(1e-6)
+ sim.run()
+
+ def test_4_step_1(self):
+ self.tst(CLDivRemShape(width=4, n_width=4),
+ full=True,
+ steps_per_clock=1)
+
+ def test_4_step_2(self):
+ self.tst(CLDivRemShape(width=4, n_width=4),
+ full=True,
+ steps_per_clock=2)
+
+ def test_4_step_3(self):
+ self.tst(CLDivRemShape(width=4, n_width=4),
+ full=True,
+ steps_per_clock=3)
if __name__ == "__main__":