assert isinstance(d, int) and 0 <= d < 1 << width
assert d != 0, "TODO: decide what happens on division by zero"
- shift_wid = (width - 1).bit_length()
+ shape = CLDivRemShape(width)
# `clz(d, width)`, but maxes out at `width - 1` instead of `width` in
- # order to both fit in `shift_wid` bits and not shift by more than needed.
+ # order to both fit in `shape.shift_width` bits and to not shift by more
+ # than needed.
shift = clz(d >> 1, width - 1)
- assert shift < 1 << shift_wid, f"shift overflows a {shift_wid}-bit signal"
+ assert 0 <= shift < 1 << shape.shift_width, "shift overflow"
d <<= shift
- assert d < 1 << width, f"d overflows a {width}-bit signal"
- n <<= shift
- assert n < 1 << (width * 2), f"n overflows a {width * 2}-bit signal"
- r = n
+ assert 0 <= d < 1 << shape.d_width, "d overflow"
+ r = n << shift
+ assert 0 <= r < 1 << shape.r_width, "r overflow"
q = 0
- for _ in range(width):
+ for step in range(width):
q <<= 1
r <<= 1
if r >> (width * 2 - 1) != 0:
r ^= d << width
q |= 1
- assert q < 1 << width, f"q overflows a {width}-bit signal"
- assert r < 1 << (width * 2), f"r overflows a {width * 2}-bit signal"
+ assert 0 <= q < 1 << shape.q_width, "q overflow"
+ assert 0 <= r < 1 << shape.r_width, "r overflow"
r >>= width
r >>= shift
return q, r
@dataclass(frozen=True, unsafe_hash=True)
class CLDivRemShape:
# Immutable, hashable description of the divider's size.  After this change
# `width` is the only field; every signal width below is derived from it.
width: int
# `n_width` is removed by this change; the numerator input is now declared
# with `shape.width` bits, the same as the divisor.
- n_width: int
def __post_init__(self):
# replaces the old `n_width >= width > 0` check now that only `width` exists
- assert self.n_width >= self.width > 0
+ assert isinstance(self.width, int) and self.width >= 1, "invalid width"
@property
def done_step(self):
+ """the step number when iteration is finished
+ -- the largest `CLDivRemState.step` will get
+ """
# one step per bit of `width`, matching the `range(width)` loop in the
# pure-Python reference function
return self.width
@property
def step_range(self):
+ """the range that `CLDivRemState.step` will fall in.
+
+ returns: range
+ """
# covers 0..done_step inclusive, hence the `+ 1`
return range(self.done_step + 1)
+ @property
+ def d_width(self):
+ """bit-width of the internal signal `CLDivRemState.d`"""
+ return self.width
+
+ @property
+ def r_width(self):
+ """bit-width of the internal signal `CLDivRemState.r`"""
# twice `width`: the reference function bounds `r = n << shift` and every
# later value of `r` by `1 << shape.r_width`
+ return self.width * 2
+
+ @property
+ def q_width(self):
+ """bit-width of the internal signal `CLDivRemState.q`"""
+ return self.width
+
+ @property
+ def shift_width(self):
+ """bit-width of the internal signal `CLDivRemState.shift`"""
# number of bits needed to represent `width - 1`, the largest shift the
# reference function allows
+ return (self.width - 1).bit_length()
+
@dataclass(frozen=True, eq=False)
class CLDivRemState:
shape: CLDivRemShape
name: str
+ step: Signal = field(init=False)
d: Signal = field(init=False)
r: Signal = field(init=False)
q: Signal = field(init=False)
- step: Signal = field(init=False)
+ shift: Signal = field(init=False)
def __init__(self, shape, *, name=None, src_loc_at=0):
# Create the per-state signals (`step`, `d`, `r`, `q`, `shift`), each named
# `<name>_<field>` and sized from `shape`.
assert isinstance(shape, CLDivRemShape)
if name is None:
# no explicit name given: build a throwaway Signal only to read its `.name`.
# NOTE(review): presumably the HDL library derives that name from the
# caller's source location, with `src_loc_at` selecting how many stack
# frames to skip -- confirm against the Signal documentation.
name = Signal(src_loc_at=1 + src_loc_at).name
assert isinstance(name, str)
# old declarations: `d` was 2 * width bits, `r` was `n_width` bits and `step`
# was `width` bits
- d = Signal(2 * shape.width, name=f"{name}_d")
- r = Signal(shape.n_width, name=f"{name}_r")
- q = Signal(shape.width, name=f"{name}_q")
- step = Signal(shape.width, name=f"{name}_step")
# new declarations: `step` is declared from `shape.step_range`
# (`range(done_step + 1)`) instead of a bit count; the other widths come
# from the matching `shape.*_width` properties, and `shift` is new
+ step = Signal(shape.step_range, name=f"{name}_step")
+ d = Signal(shape.d_width, name=f"{name}_d")
+ r = Signal(shape.r_width, name=f"{name}_r")
+ q = Signal(shape.q_width, name=f"{name}_q")
+ shift = Signal(shape.shift_width, name=f"{name}_shift")
# the class is declared `@dataclass(frozen=True, ...)`, so plain attribute
# assignment is blocked; go through `object.__setattr__` instead
object.__setattr__(self, "shape", shape)
object.__setattr__(self, "name", name)
+ object.__setattr__(self, "step", step)
object.__setattr__(self, "d", d)
object.__setattr__(self, "r", r)
object.__setattr__(self, "q", q)
- object.__setattr__(self, "step", step)
+ object.__setattr__(self, "shift", shift)
def eq(self, rhs):
assert isinstance(rhs, CLDivRemState)
assert isinstance(steps, int) and steps >= 0
return self.step >= max(0, self.shape.done_step - steps)
+ def get_output(self):
# Returns the pair `(q, r_out)` for this state.  `r_out` is `self.r` with its
# low `shape.width` bits dropped and then shifted right by `self.shift`,
# i.e. the same `r >>= width; r >>= shift` the pure-Python reference
# function applies just before `return q, r`.
+ return self.q, (self.r >> self.shape.width) >> self.shift
+
def set_to_initial(self, m, n, d):
# Combinatorially drive this state with the starting values for dividing
# `n` by `d`: `q = 0`, `step = 0`, and (new code) `d` and `n` both shifted
# left by `shift`.  Adds one anonymous CLZ submodule to `m` per call.
assert isinstance(m, Module)
+ n = Value.cast(n) # convert to Value
+ d = Value.cast(d) # convert to Value
# hardware counterpart of `clz(d >> 1, width - 1)` in the reference function.
# NOTE(review): assumes `CLZ(w)` is a count-leading-zeros module over `w`
# bits with input `sig_in` and output `lz` -- confirm against its definition.
+ clz_mod = CLZ(self.shape.width - 1)
+ # can't name submodule since it would conflict if this function is
+ # called multiple times in a Module
+ m.submodules += clz_mod
+ assert clz_mod.lz.width == self.shape.shift_width, \
+ "internal inconsistency -- mismatched shift signal width"
m.d.comb += [
# old behaviour: `d` placed in the upper half, `r` loaded with `n` unshifted
- self.d.eq(Value.cast(d) << self.shape.width),
- self.r.eq(n),
# new behaviour: shift amount comes from the CLZ of `d >> 1`, and both `d`
# and `n` are shifted left by that amount
+ clz_mod.sig_in.eq(d >> 1),
+ self.shift.eq(clz_mod.lz),
+ self.d.eq(d << self.shift),
+ self.r.eq(n << self.shift),
self.q.eq(0),
self.step.eq(0),
]
assert isinstance(state_in, CLDivRemState)
assert state_in.shape == self.shape
assert self is not state_in, "a.set_to_next(m, a) is not allowed"
-
- equal_leading_zero_count = EqualLeadingZeroCount(self.shape.n_width)
- # can't name submodule since it would conflict if this function is
- # called multiple times in a Module
- m.submodules += equal_leading_zero_count
+ width = self.shape.width
with m.If(state_in.done):
m.d.comb += self.eq(state_in)
with m.Else():
m.d.comb += [
self.step.eq(state_in.step + 1),
- self.d.eq(state_in.d >> 1),
- equal_leading_zero_count.a.eq(self.d),
- equal_leading_zero_count.b.eq(state_in.r),
+ self.d.eq(state_in.d),
+ self.shift.eq(state_in.shift),
]
- d_top = self.d[self.shape.n_width:]
- with m.If(equal_leading_zero_count.out & (d_top == 0)):
+ q = state_in.q << 1
+ r = state_in.r << 1
+ with m.If(r[width * 2 - 1]):
m.d.comb += [
- self.r.eq(state_in.r ^ self.d),
- self.q.eq((state_in.q << 1) | 1),
+ self.q.eq(q | 1),
+ self.r.eq(r ^ (state_in.d << width)),
]
with m.Else():
m.d.comb += [
- self.r.eq(state_in.r),
- self.q.eq(state_in.q << 1),
+ self.q.eq(q),
+ self.r.eq(r),
]
def __init__(self, shape):
# Declares the two data inputs, `n` and `d`, sized from `shape`.
assert isinstance(shape, CLDivRemShape)
self.shape = shape
# `n` used the now-removed `shape.n_width`; both inputs are `shape.width`
# bits wide after this change
- self.n = Signal(shape.n_width)
+ self.n = Signal(shape.width)
self.d = Signal(shape.width)
def __iter__(self):
the shape
steps_per_clock: int
number of steps that should be taken per clock cycle
- in_valid: Signal()
- input. true when the data inputs (`n` and `d`) are valid.
- data transfer in occurs when `in_valid & in_ready`.
- in_ready: Signal()
- output. true when this FSM is ready to accept input.
- data transfer in occurs when `in_valid & in_ready`.
- n: Signal(shape.n_width)
- numerator in, the value must be small enough that `q` and `r` don't
- overflow. having `n_width == width` is sufficient.
- d: Signal(shape.width)
- denominator in, must be non-zero.
- q: Signal(shape.width)
- quotient out.
- r: Signal(shape.width)
- remainder out.
- out_valid: Signal()
- output. true when the data outputs (`q` and `r`) are valid
- (or are junk because the inputs were out of range).
- data transfer out occurs when `out_valid & out_ready`.
- out_ready: Signal()
- input. true when the output can be read.
- data transfer out occurs when `out_valid & out_ready`.
+ pspec:
+ pipe-spec
+ empty: Signal()
+ true if nothing is stored in `self.saved_state`
+ saved_state: CLDivRemState()
+ the saved state that is currently being worked on.
"""
- def __init__(self, pspec, shape, *, steps_per_clock=4):
+ def __init__(self, pspec, shape, *, steps_per_clock=8):
assert isinstance(shape, CLDivRemShape)
assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
self.shape = shape
# TODO: handle cancellation
- state_will_be_done = self.saved_state.will_be_done_after(
- self.steps_per_clock)
- m.d.comb += self.n.o_valid.eq(~self.empty & state_will_be_done)
+ m.d.comb += self.n.o_valid.eq(~self.empty & self.saved_state.done)
m.d.comb += self.p.o_ready.eq(self.empty)
def make_nc(i):
next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)]
for i in range(self.steps_per_clock):
next_chain[i + 1].set_to_next(m, next_chain[i])
- m.d.sync += self.saved_state.eq(next_chain[-1])
- m.d.comb += o_data.q.eq(next_chain[-1].q)
- m.d.comb += o_data.r.eq(next_chain[-1].r)
+ m.d.comb += next_chain[0].eq(self.saved_state)
+ out_q, out_r = self.saved_state.get_output()
+ m.d.comb += o_data.q.eq(out_q)
+ m.d.comb += o_data.r.eq(out_r)
+ initial_state = CLDivRemState(self.shape)
+ initial_state.set_to_initial(m, n=i_data.n, d=i_data.d)
with m.If(self.empty):
- next_chain[0].set_to_initial(m, n=i_data.n, d=i_data.d)
+ m.d.sync += self.saved_state.eq(initial_state)
with m.If(self.p.i_valid):
m.d.sync += self.empty.eq(0)
with m.Else():
- m.d.comb += next_chain[0].eq(self.saved_state)
+ m.d.sync += self.saved_state.eq(next_chain[-1])
with m.If(self.n.i_ready & self.n.o_valid):
m.d.sync += self.empty.eq(1)
-
return m
def __iter__(self):
class TestCLDivRemComb(FHDLTestCase):
def tst(self, shape, full):
assert isinstance(shape, CLDivRemShape)
+ width = shape.width
m = Module()
- n_in = Signal(shape.n_width)
- d_in = Signal(shape.width)
+ n_in = Signal(width)
+ d_in = Signal(width)
+ q_out = Signal(width)
+ r_out = Signal(width)
states: "list[CLDivRemState]" = []
for i in shape.step_range:
states.append(CLDivRemState(shape, name=f"state_{i}"))
states[i].set_to_initial(m, n=n_in, d=d_in)
else:
states[i].set_to_next(m, states[i - 1])
+ q, r = states[-1].get_output()
+ m.d.comb += [q_out.eq(q), r_out.eq(r)]
def case(n, d):
assert isinstance(n, int)
assert isinstance(d, int)
- max_width = max(shape.width, shape.n_width)
if d != 0:
- expected_q, expected_r = cldivrem(n, d, width=max_width)
+ expected_q, expected_r = cldivrem_shifting(n, d, width)
else:
expected_q = expected_r = 0
with self.subTest(n=hex(n), d=hex(d),
step = yield states[i].step
self.assertEqual(done, i >= shape.done_step)
self.assertEqual(step, i)
- q = yield states[-1].q
- r = yield states[-1].r
+ q = yield q_out
+ r = yield r_out
with self.subTest(q=hex(q), r=hex(r)):
# only check results when inputs are valid
- if d != 0 and (expected_q >> shape.width) == 0:
+ if d != 0:
self.assertEqual(q, expected_q)
self.assertEqual(r, expected_r)
def process():
if full:
- for n in range(1 << shape.n_width):
- for d in range(1 << shape.width):
+ for n in range(1 << width):
+ for d in range(1 << width):
yield from case(n, d)
else:
for i in range(100):
n = hash_256(f"cldivrem comb n {i}")
- n = Const.normalize(n, unsigned(shape.n_width))
+ n = Const.normalize(n, unsigned(width))
d = hash_256(f"cldivrem comb d {i}")
- d = Const.normalize(d, unsigned(shape.width))
+ d = Const.normalize(d, unsigned(width))
yield from case(n, d)
- with do_sim(self, m, [n_in, d_in, states[-1].q, states[-1].r]) as sim:
+ with do_sim(self, m, [n_in, d_in, q_out, r_out]) as sim:
sim.add_process(process)
sim.run()
def test_4(self):
# width 4 with `full=True`: `tst` iterates every `(n, d)` pair in
# `range(1 << width)` x `range(1 << width)`
- self.tst(CLDivRemShape(width=4, n_width=4), full=True)
+ self.tst(CLDivRemShape(width=4), full=True)
+
+ def test_6(self):
# width 6 with `full=True`: exhaustive over all 6-bit `(n, d)` pairs
+ self.tst(CLDivRemShape(width=6), full=True)
# the mixed-width case (`width=4, n_width=8`) is removed together with
# `n_width`; it is replaced by a plain 8-bit test
- def test_8_by_4(self):
- self.tst(CLDivRemShape(width=4, n_width=8), full=True)
+ def test_8(self):
# `full=False`: `tst` runs 100 cases whose `n`/`d` come from `hash_256` of a
# per-index string, instead of iterating all 8-bit pairs
+ self.tst(CLDivRemShape(width=8), full=False)
class TestCLDivRemFSM(FHDLTestCase):
dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
i_data: CLDivRemInputData = dut.p.i_data
o_data: CLDivRemOutputData = dut.n.o_data
- self.assertEqual(i_data.n.shape(), unsigned(shape.n_width))
+ self.assertEqual(i_data.n.shape(), unsigned(shape.width))
self.assertEqual(i_data.d.shape(), unsigned(shape.width))
self.assertEqual(o_data.q.shape(), unsigned(shape.width))
self.assertEqual(o_data.r.shape(), unsigned(shape.width))
def case(n, d):
assert isinstance(n, int)
assert isinstance(d, int)
- max_width = max(shape.width, shape.n_width)
if d != 0:
- expected_q, expected_r = cldivrem(n, d, width=max_width)
+ expected_q, expected_r = cldivrem(n, d, width=shape.width)
else:
expected_q = expected_r = 0
with self.subTest(n=hex(n), d=hex(d),
yield i_data.n.eq(-1)
yield i_data.d.eq(-1)
yield dut.p.i_valid.eq(0)
- for i in range(steps_per_clock * 2, shape.done_step,
- steps_per_clock):
+ for step in range(0, shape.done_step, steps_per_clock):
yield Delay(0.1e-6)
valid = yield dut.n.o_valid
ready = yield dut.p.o_ready
def process():
if full:
- for n in range(1 << shape.n_width):
+ for n in range(1 << shape.width):
for d in range(1 << shape.width):
yield from case(n, d)
else:
for i in range(100):
n = hash_256(f"cldivrem fsm n {i}")
- n = Const.normalize(n, unsigned(shape.n_width))
+ n = Const.normalize(n, unsigned(shape.width))
d = hash_256(f"cldivrem fsm d {i}")
d = Const.normalize(d, unsigned(shape.width))
yield from case(n, d)
sim.run()
def test_4_step_1(self):
# exhaustive 4-bit run of the FSM stage taking one step per clock
- self.tst(CLDivRemShape(width=4, n_width=4),
+ self.tst(CLDivRemShape(width=4),
full=True,
steps_per_clock=1)
def test_4_step_2(self):
# exhaustive 4-bit run of the FSM stage taking two steps per clock
- self.tst(CLDivRemShape(width=4, n_width=4),
+ self.tst(CLDivRemShape(width=4),
full=True,
steps_per_clock=2)
def test_4_step_3(self):
# exhaustive 4-bit run with three steps per clock; 3 does not evenly divide
# the 4 steps (`done_step == width`) a 4-bit division needs
- self.tst(CLDivRemShape(width=4, n_width=4),
+ self.tst(CLDivRemShape(width=4),
full=True,
steps_per_clock=3)
+ def test_4_step_4(self):
# exhaustive 4-bit run where `steps_per_clock` equals `done_step` (4)
+ self.tst(CLDivRemShape(width=4),
+ full=True,
+ steps_per_clock=4)
+
+ def test_8_step_4(self):
# 8-bit run with `full=False`: 100 `hash_256`-derived `(n, d)` cases, four
# steps per clock
+ self.tst(CLDivRemShape(width=8),
+ full=False,
+ steps_per_clock=4)
+
+ def test_64_step_4(self):
# 64-bit run with `full=False`: 100 `hash_256`-derived `(n, d)` cases, four
# steps per clock
+ self.tst(CLDivRemShape(width=64),
+ full=False,
+ steps_per_clock=4)
+
+ def test_64_step_8(self):
# 64-bit run with `full=False`, using 8 steps per clock -- the value this
# change makes the FSM stage's default `steps_per_clock`
+ self.tst(CLDivRemShape(width=64),
+ full=False,
+ steps_per_clock=8)
+
# run the unittest test cases in this module when it is executed as a script
if __name__ == "__main__":
unittest.main()