From 6e87ea5bc82330a687465f8921dbe39c2570017b Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 17 Jul 2020 20:16:27 -0700 Subject: [PATCH] add div fsm core (`DivState*`) with tests comb test works fsm test fails for some reason --- src/soc/fu/div/fsm.py | 131 ++++++++++++--- src/soc/fu/div/pipeline.py | 33 +++- src/soc/fu/div/test/test_fsm.py | 288 ++++++++++++++++++++++++++++++++ 3 files changed, 419 insertions(+), 33 deletions(-) create mode 100644 src/soc/fu/div/test/test_fsm.py diff --git a/src/soc/fu/div/fsm.py b/src/soc/fu/div/fsm.py index bae1a43e..10bc80d3 100644 --- a/src/soc/fu/div/fsm.py +++ b/src/soc/fu/div/fsm.py @@ -1,6 +1,8 @@ import enum -from nmigen import Elaboratable, Module, Signal -from soc.fu.div.pipe_data import CoreInputData, CoreOutputData +from nmigen import Elaboratable, Module, Signal, Shape, unsigned, Cat, Mux +from soc.fu.div.pipe_data import CoreInputData, CoreOutputData, DivPipeSpec +from nmutil.iocontrol import PrevControl, NextControl +from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation class FSMDivCoreConfig: @@ -49,51 +51,130 @@ class FSMDivCoreOutputData: self.remainder.eq(rhs.remainder)] -class FSMDivCorePrev: +class FSMDivCorePrevControl(PrevControl): + data_i: CoreInputData + def __init__(self, pspec): + super().__init__(stage_ctl=True, maskwid=pspec.id_wid) + self.pspec = pspec self.data_i = CoreInputData(pspec) - self.valid_i = Signal() - self.ready_o = Signal() - def __iter__(self): - yield from self.data_i - yield self.valid_i - yield self.ready_o +class FSMDivCoreNextControl(NextControl): + data_o: CoreOutputData -class FSMDivCoreNext: def __init__(self, pspec): + super().__init__(stage_ctl=True, maskwid=pspec.id_wid) + self.pspec = pspec self.data_o = CoreOutputData(pspec) - self.valid_o = Signal() - self.ready_i = Signal() - def __iter__(self): - yield from self.data_o - yield self.valid_o - yield self.ready_i +class DivStateNext(Elaboratable): + def __init__(self, quotient_width): + self.quotient_width = quotient_width + self.i = DivState(quotient_width=quotient_width, name="i") + self.divisor = Signal(quotient_width) + self.o = DivState(quotient_width=quotient_width, name="o") + + def elaborate(self, platform): + m = Module() + difference = Signal(self.i.quotient_width * 2) + m.d.comb += difference.eq(self.i.dividend_quotient + - (self.divisor + << (self.quotient_width - 1))) + next_quotient_bit = Signal() + m.d.comb += next_quotient_bit.eq( + ~difference[self.quotient_width * 2 - 1]) + value = Signal(self.i.quotient_width * 2) + with m.If(next_quotient_bit): + m.d.comb += value.eq(difference) + with m.Else(): + m.d.comb += value.eq(self.i.dividend_quotient) + + with m.If(self.i.done): + m.d.comb += self.o.eq(self.i) + with m.Else(): + m.d.comb += [ + self.o.q_bits_known.eq(self.i.q_bits_known + 1), + self.o.dividend_quotient.eq(Cat(next_quotient_bit, value))] + return m + + +class DivStateInit(Elaboratable): + def __init__(self, quotient_width): + self.quotient_width = quotient_width + self.dividend = Signal(quotient_width * 2) + self.o = DivState(quotient_width=quotient_width, name="o") + + def elaborate(self, platform): + m = Module() + m.d.comb += self.o.q_bits_known.eq(0) + m.d.comb += self.o.dividend_quotient.eq(self.dividend) + return m + + +class DivState: + def __init__(self, quotient_width, name): + self.quotient_width = quotient_width + self.q_bits_known = Signal(range(1 + quotient_width), + name=name + "_q_bits_known") + self.dividend_quotient = Signal(unsigned(2 * quotient_width), + name=name + "_dividend_quotient") -class DivState(enum.Enum): - Empty = 0 - Computing = 1 - WaitingOnOutput = 2 + @property + def done(self): + return self.q_bits_known == self.quotient_width + + @property + def quotient(self): + """ get the quotient -- requires self.done is True """ + return self.dividend_quotient[0:self.quotient_width] + + @property + def remainder(self): + """ get the remainder -- requires self.done is True """ + return self.dividend_quotient[self.quotient_width:self.quotient_width*2] + + def eq(self, rhs): + return [self.q_bits_known.eq(rhs.q_bits_known), + self.dividend_quotient.eq(rhs.dividend_quotient)] class FSMDivCoreStage(Elaboratable): - def __init__(self, pspec): - self.p = FSMDivCorePrev(pspec) - self.n = FSMDivCoreNext(pspec) + def __init__(self, pspec: DivPipeSpec): + self.pspec = pspec + self.p = FSMDivCorePrevControl(pspec) + self.n = FSMDivCoreNextControl(pspec) self.saved_input_data = CoreInputData(pspec) self.canceled = Signal() - self.state = Signal(DivState, reset=DivState.Empty) + self.empty = Signal(reset=1) + self.saved_state = DivState(64) def elaborate(self, platform): m = Module() + m.submodules.p = self.p + m.submodules.n = self.n + data_i = self.p.data_i + data_o = self.p.data_o # TODO: calculate self.canceled from self.p.data_i.ctx m.d.comb += self.canceled.eq(False) - # TODO(programmerjake): finish + # TODO: adapt to refactored DivState interface + fsm_state_in = DivState(64) + divisor = Signal(unsigned(64)) + fsm_state_out = fsm_state_in.make_next_state(m, divisor) + + with m.If(self.canceled): + with m.If(self.p.valid_i): + ... + with m.Else(): + ... + with m.Else(): + with m.If(self.p.valid_i): + ... + with m.Else(): + ... return m diff --git a/src/soc/fu/div/pipeline.py b/src/soc/fu/div/pipeline.py index 2ea291e9..80f5a94b 100644 --- a/src/soc/fu/div/pipeline.py +++ b/src/soc/fu/div/pipeline.py @@ -6,18 +6,26 @@ from soc.fu.div.output_stage import DivOutputStage from soc.fu.div.setup_stage import DivSetupStage from soc.fu.div.core_stages import (DivCoreSetupStage, DivCoreCalculateStage, DivCoreFinalStage) +from soc.fu.div.pipe_data import DivPipeKindConfigCombPipe class DivStagesStart(PipeModBaseChain): def get_chain(self): alu_input = DivMulInputStage(self.pspec) div_setup = DivSetupStage(self.pspec) - core_setup = DivCoreSetupStage(self.pspec) - return [alu_input, div_setup, core_setup] + if isinstance(self.pspec.div_pipe_kind.config, + DivPipeKindConfigCombPipe): + core_setup = [DivCoreSetupStage(self.pspec)] + else: + core_setup = () + return [alu_input, div_setup, *core_setup] class DivStagesMiddle(PipeModBaseChain): def __init__(self, pspec, stage_start_index, stage_end_index): + assert isinstance(pspec.div_pipe_kind.config, + DivPipeKindConfigCombPipe),\ + "DivStagesMiddle must be used with a DivPipeKindConfigCombPipe" self.stage_start_index = stage_start_index self.stage_end_index = stage_end_index super().__init__(pspec) @@ -31,11 +39,15 @@ class DivStagesMiddle(PipeModBaseChain): class DivStagesEnd(PipeModBaseChain): def get_chain(self): - core_final = DivCoreFinalStage(self.pspec) + if isinstance(self.pspec.div_pipe_kind.config, + DivPipeKindConfigCombPipe): + core_final = [DivCoreFinalStage(self.pspec)] + else: + core_final = () div_out = DivOutputStage(self.pspec) alu_out = DivMulOutputStage(self.pspec) self.div_out = div_out # debugging - bug #425 - return [core_final, div_out, alu_out] + return [*core_final, div_out, alu_out] class DivBasePipe(ControlBase): @@ -43,11 +55,16 @@ class DivBasePipe(ControlBase): ControlBase.__init__(self) self.pspec = pspec self.pipe_start = DivStagesStart(pspec) - compute_steps = pspec.core_config.n_stages self.pipe_middles = [] - for start in range(0, compute_steps, compute_steps_per_stage): - end = min(start + compute_steps_per_stage, compute_steps) - self.pipe_middles.append(DivStagesMiddle(pspec, start, end)) + if isinstance(self.pspec.div_pipe_kind.config, + DivPipeKindConfigCombPipe): + compute_steps = pspec.core_config.n_stages + for start in range(0, compute_steps, compute_steps_per_stage): + end = min(start + compute_steps_per_stage, compute_steps) + self.pipe_middles.append(DivStagesMiddle(pspec, start, end)) + else: + self.pipe_middles.append( + self.pspec.div_pipe_kind.config.core_stage_class(pspec)) self.pipe_end = DivStagesEnd(pspec) self._eqs = self.connect([self.pipe_start, *self.pipe_middles, diff --git a/src/soc/fu/div/test/test_fsm.py b/src/soc/fu/div/test/test_fsm.py new file mode 100644 index 00000000..75f606fc --- /dev/null +++ b/src/soc/fu/div/test/test_fsm.py @@ -0,0 +1,288 @@ +import unittest +from soc.fu.div.fsm import DivState, DivStateInit, DivStateNext +from nmigen import Elaboratable, Module, Signal, unsigned +from nmigen.cli import rtlil +from nmigen.sim.pysim import Simulator, Delay, Tick + + +class CheckEvent(Elaboratable): + """helper to add indication to vcd when signals are checked""" + + def __init__(self): + self.event = Signal() + + def trigger(self): + yield self.event.eq(~self.event) + + def elaborate(self, platform): + m = Module() + # use event somehow so nmigen simulation knows about it + m.d.comb += Signal().eq(self.event) + return m + + +class DivStateCombTest(Elaboratable): + """Test stringing a bunch of copies of the FSM state-function together""" + + def __init__(self, quotient_width): + self.check_event = CheckEvent() + self.quotient_width = quotient_width + self.dividend = Signal(unsigned(quotient_width * 2)) + self.divisor = Signal(unsigned(quotient_width)) + self.quotient = Signal(unsigned(quotient_width)) + self.remainder = Signal(unsigned(quotient_width)) + self.expected_quotient = Signal(unsigned(quotient_width)) + self.expected_remainder = Signal(unsigned(quotient_width)) + self.expected_valid = Signal() + self.states = [] + for i in range(quotient_width + 1): + state = DivState(quotient_width=quotient_width, name=f"state{i}") + self.states.append(state) + self.init = DivStateInit(quotient_width) + self.nexts = [] + for i in range(quotient_width): + next = DivStateNext(quotient_width) + self.nexts.append(next) + + def elaborate(self, platform): + m = Module() + m.submodules.check_event = self.check_event + m.submodules.init = self.init + m.d.comb += self.init.dividend.eq(self.dividend) + m.d.comb += self.states[0].eq(self.init.o) + last_state = self.states[0] + for i in range(self.quotient_width): + setattr(m.submodules, f"next{i}", self.nexts[i]) + m.d.comb += self.nexts[i].divisor.eq(self.divisor) + m.d.comb += self.nexts[i].i.eq(last_state) + last_state = self.states[i + 1] + m.d.comb += last_state.eq(self.nexts[i].o) + m.d.comb += self.quotient.eq(last_state.quotient) + m.d.comb += self.remainder.eq(last_state.remainder) + m.d.comb += self.expected_valid.eq( + (self.dividend < (self.divisor << self.quotient_width)) + & (self.divisor != 0)) + with m.If(self.expected_valid): + m.d.comb += self.expected_quotient.eq( + self.dividend // self.divisor) + m.d.comb += self.expected_remainder.eq( + self.dividend % self.divisor) + return m + + +class DivStateFSMTest(Elaboratable): + def __init__(self, quotient_width): + self.check_done_event = CheckEvent() + self.check_event = CheckEvent() + self.quotient_width = quotient_width + self.dividend = Signal(unsigned(quotient_width * 2)) + self.divisor = Signal(unsigned(quotient_width)) + self.quotient = Signal(unsigned(quotient_width)) + self.remainder = Signal(unsigned(quotient_width)) + self.expected_quotient = Signal(unsigned(quotient_width)) + self.expected_remainder = Signal(unsigned(quotient_width)) + self.expected_valid = Signal() + self.state = DivState(quotient_width=quotient_width, + name="state") + self.next_state = DivState(quotient_width=quotient_width, + name="next_state") + self.init = DivStateInit(quotient_width) + self.next = DivStateNext(quotient_width) + self.state_done = Signal() + self.next_state_done = Signal() + self.clear = Signal(reset=1) + + def elaborate(self, platform): + m = Module() + m.submodules.check_event = self.check_event + m.submodules.check_done_event = self.check_done_event + m.submodules.init = self.init + m.submodules.next = self.next + m.d.comb += self.init.dividend.eq(self.dividend) + m.d.comb += self.next.divisor.eq(self.divisor) + m.d.comb += self.quotient.eq(self.state.quotient) + m.d.comb += self.remainder.eq(self.state.remainder) + m.d.comb += self.next.i.eq(self.state) + m.d.comb += self.state_done.eq(self.state.done) + m.d.comb += self.next_state_done.eq(self.next_state.done) + + with m.If(self.state.done | self.clear): + m.d.comb += self.next_state.eq(self.init.o) + with m.Else(): + m.d.comb += self.next_state.eq(self.next.o) + + m.d.sync += self.state.eq(self.next_state) + + m.d.comb += self.expected_valid.eq( + (self.dividend < (self.divisor << self.quotient_width)) + & (self.divisor != 0)) + with m.If(self.expected_valid): + m.d.comb += self.expected_quotient.eq( + self.dividend // self.divisor) + m.d.comb += self.expected_remainder.eq( + self.dividend % self.divisor) + return m + + +def get_cases(quotient_width): + test_cases = [] + mask = ~(~0 << quotient_width) + for i in range(-3, 4): + test_cases.append(i & mask) + for i in [-1, 0, 1]: + test_cases.append((i + (mask >> 1)) & mask) + test_cases.sort() + return test_cases + + +class TestDivState(unittest.TestCase): + def test_div_state_comb(self, quotient_width=8): + test_cases = get_cases(quotient_width) + mask = ~(~0 << quotient_width) + dut = DivStateCombTest(quotient_width) + vl = rtlil.convert(dut, + ports=[dut.dividend, + dut.divisor, + dut.quotient, + dut.remainder]) + with open("div_fsm_comb_pipeline.il", "w") as f: + f.write(vl) + dut = DivStateCombTest(quotient_width) + + def check(dividend, divisor): + with self.subTest(dividend=f"{dividend:#x}", + divisor=f"{divisor:#x}"): + yield from dut.check_event.trigger() + for i in range(quotient_width + 1): + # done must be correct and eventually true + # even if a div-by-zero or overflow occurred + done = yield dut.states[i].done + self.assertEqual(done, i == quotient_width) + if divisor != 0: + quotient = dividend // divisor + remainder = dividend % divisor + if quotient <= mask: + with self.subTest(quotient=f"{quotient:#x}", + remainder=f"{remainder:#x}"): + self.assertTrue((yield dut.expected_valid)) + self.assertEqual((yield dut.expected_quotient), quotient) + self.assertEqual((yield dut.expected_remainder), remainder) + self.assertEqual((yield dut.quotient), quotient) + self.assertEqual((yield dut.remainder), remainder) + else: + self.assertFalse((yield dut.expected_valid)) + else: + self.assertFalse((yield dut.expected_valid)) + + def process(gen): + for dividend_high in test_cases: + for dividend_low in test_cases: + dividend = dividend_low + \ + (dividend_high << quotient_width) + for divisor in test_cases: + if gen: + yield Delay(0.5e-6) + yield dut.dividend.eq(dividend) + yield dut.divisor.eq(divisor) + yield Delay(0.5e-6) + else: + yield Delay(1e-6) + yield from check(dividend, divisor) + + def gen_process(): + yield from process(gen=True) + + def check_process(): + yield from process(gen=False) + + sim = Simulator(dut) + with sim.write_vcd(vcd_file="div_fsm_comb_pipeline.vcd", + gtkw_file="div_fsm_comb_pipeline.gtkw"): + + sim.add_process(gen_process) + sim.add_process(check_process) + sim.run() + + def test_div_state_fsm(self, quotient_width=8): + # TODO(programmerjake): fix test: for some reason + # the check process is delayed to the second division + # before it tries to do the first check + test_cases = get_cases(quotient_width) + mask = ~(~0 << quotient_width) + dut = DivStateFSMTest(quotient_width) + vl = rtlil.convert(dut, + ports=[dut.dividend, + dut.divisor, + dut.quotient, + dut.remainder]) + with open("div_fsm.il", "w") as f: + f.write(vl) + + def check(dividend, divisor): + with self.subTest(dividend=f"{dividend:#x}", + divisor=f"{divisor:#x}"): + for i in range(quotient_width + 1): + yield Tick() + yield Delay(0.1e-6) + yield from dut.check_done_event.trigger() + # done must be correct and eventually true + # even if a div-by-zero or overflow occurred + done = yield dut.state.done + self.assertEqual(done, i == quotient_width) + yield from dut.check_event.trigger() + if divisor != 0: + quotient = dividend // divisor + remainder = dividend % divisor + if quotient <= mask: + with self.subTest(quotient=f"{quotient:#x}", + remainder=f"{remainder:#x}"): + self.assertTrue((yield dut.expected_valid)) + self.assertEqual((yield dut.expected_quotient), quotient) + self.assertEqual((yield dut.expected_remainder), remainder) + self.assertEqual((yield dut.quotient), quotient) + self.assertEqual((yield dut.remainder), remainder) + else: + self.assertFalse((yield dut.expected_valid)) + else: + self.assertFalse((yield dut.expected_valid)) + + def process(gen): + if gen: + yield dut.clear.eq(1) + else: + yield from dut.check_event.trigger() + yield from dut.check_done_event.trigger() + yield Tick() + for dividend_high in test_cases: + for dividend_low in test_cases: + dividend = dividend_low + \ + (dividend_high << quotient_width) + for divisor in test_cases: + if gen: + yield Delay(0.2e-6) + yield dut.clear.eq(0) + yield dut.dividend.eq(dividend) + yield dut.divisor.eq(divisor) + for _ in range(quotient_width): + yield Tick() + else: + yield from check(dividend, divisor) + + def gen_process(): + yield from process(gen=True) + + def check_process(): + yield from process(gen=False) + + sim = Simulator(dut) + with sim.write_vcd(vcd_file="div_fsm.vcd", + gtkw_file="div_fsm.gtkw"): + + sim.add_clock(1e-6) + sim.add_process(gen_process) + sim.add_process(check_process) + sim.run() + + +if __name__ == "__main__": + unittest.main() -- 2.30.2