add div fsm core (`DivState*`) with tests
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 18 Jul 2020 03:16:27 +0000 (20:16 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 18 Jul 2020 03:16:27 +0000 (20:16 -0700)
comb test works
fsm test fails for some reason

src/soc/fu/div/fsm.py
src/soc/fu/div/pipeline.py
src/soc/fu/div/test/test_fsm.py [new file with mode: 0644]

index bae1a43e625c43cbdca504502a2ecce15cd67495..10bc80d34c2816c69c01056562d9e1ca0292585b 100644 (file)
@@ -1,6 +1,8 @@
 import enum
-from nmigen import Elaboratable, Module, Signal
-from soc.fu.div.pipe_data import CoreInputData, CoreOutputData
+from nmigen import Elaboratable, Module, Signal, Shape, unsigned, Cat, Mux
+from soc.fu.div.pipe_data import CoreInputData, CoreOutputData, DivPipeSpec
+from nmutil.iocontrol import PrevControl, NextControl
+from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation
 
 
 class FSMDivCoreConfig:
@@ -49,51 +51,130 @@ class FSMDivCoreOutputData:
                 self.remainder.eq(rhs.remainder)]
 
 
-class FSMDivCorePrev:
+class FSMDivCorePrevControl(PrevControl):
+    data_i: CoreInputData
+
     def __init__(self, pspec):
+        super().__init__(stage_ctl=True, maskwid=pspec.id_wid)
+        self.pspec = pspec
         self.data_i = CoreInputData(pspec)
-        self.valid_i = Signal()
-        self.ready_o = Signal()
 
-    def __iter__(self):
-        yield from self.data_i
-        yield self.valid_i
-        yield self.ready_o
 
+class FSMDivCoreNextControl(NextControl):
+    data_o: CoreOutputData
 
-class FSMDivCoreNext:
     def __init__(self, pspec):
+        super().__init__(stage_ctl=True, maskwid=pspec.id_wid)
+        self.pspec = pspec
         self.data_o = CoreOutputData(pspec)
-        self.valid_o = Signal()
-        self.ready_i = Signal()
 
-    def __iter__(self):
-        yield from self.data_o
-        yield self.valid_o
-        yield self.ready_i
 
+class DivStateNext(Elaboratable):
+    def __init__(self, quotient_width):
+        self.quotient_width = quotient_width
+        self.i = DivState(quotient_width=quotient_width, name="i")
+        self.divisor = Signal(quotient_width)
+        self.o = DivState(quotient_width=quotient_width, name="o")
+
+    def elaborate(self, platform):
+        m = Module()
+        difference = Signal(self.i.quotient_width * 2)
+        m.d.comb += difference.eq(self.i.dividend_quotient
+                                  - (self.divisor
+                                     << (self.quotient_width - 1)))
+        next_quotient_bit = Signal()
+        m.d.comb += next_quotient_bit.eq(
+            ~difference[self.quotient_width * 2 - 1])
+        value = Signal(self.i.quotient_width * 2)
+        with m.If(next_quotient_bit):
+            m.d.comb += value.eq(difference)
+        with m.Else():
+            m.d.comb += value.eq(self.i.dividend_quotient)
+
+        with m.If(self.i.done):
+            m.d.comb += self.o.eq(self.i)
+        with m.Else():
+            m.d.comb += [
+                self.o.q_bits_known.eq(self.i.q_bits_known + 1),
+                self.o.dividend_quotient.eq(Cat(next_quotient_bit, value))]
+        return m
+
+
+class DivStateInit(Elaboratable):
+    def __init__(self, quotient_width):
+        self.quotient_width = quotient_width
+        self.dividend = Signal(quotient_width * 2)
+        self.o = DivState(quotient_width=quotient_width, name="o")
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.o.q_bits_known.eq(0)
+        m.d.comb += self.o.dividend_quotient.eq(self.dividend)
+        return m
+
+
+class DivState:
+    def __init__(self, quotient_width, name):
+        self.quotient_width = quotient_width
+        self.q_bits_known = Signal(range(1 + quotient_width),
+                                   name=name + "_q_bits_known")
+        self.dividend_quotient = Signal(unsigned(2 * quotient_width),
+                                        name=name + "_dividend_quotient")
 
-class DivState(enum.Enum):
-    Empty = 0
-    Computing = 1
-    WaitingOnOutput = 2
+    @property
+    def done(self):
+        return self.q_bits_known == self.quotient_width
+
+    @property
+    def quotient(self):
+        """ get the quotient -- requires self.done is True """
+        return self.dividend_quotient[0:self.quotient_width]
+
+    @property
+    def remainder(self):
+        """ get the remainder -- requires self.done is True """
+        return self.dividend_quotient[self.quotient_width:self.quotient_width*2]
+
+    def eq(self, rhs):
+        return [self.q_bits_known.eq(rhs.q_bits_known),
+                self.dividend_quotient.eq(rhs.dividend_quotient)]
 
 
 class FSMDivCoreStage(Elaboratable):
-    def __init__(self, pspec):
-        self.p = FSMDivCorePrev(pspec)
-        self.n = FSMDivCoreNext(pspec)
+    def __init__(self, pspec: DivPipeSpec):
+        self.pspec = pspec
+        self.p = FSMDivCorePrevControl(pspec)
+        self.n = FSMDivCoreNextControl(pspec)
         self.saved_input_data = CoreInputData(pspec)
         self.canceled = Signal()
-        self.state = Signal(DivState, reset=DivState.Empty)
+        self.empty = Signal(reset=1)
+        self.saved_state = DivState(64)
 
     def elaborate(self, platform):
         m = Module()
+        m.submodules.p = self.p
+        m.submodules.n = self.n
+        data_i = self.p.data_i
+        data_o = self.p.data_o
 
         # TODO: calculate self.canceled from self.p.data_i.ctx
         m.d.comb += self.canceled.eq(False)
 
-        # TODO(programmerjake): finish
+        # TODO: adapt to refactored DivState interface
+        fsm_state_in = DivState(64)
+        divisor = Signal(unsigned(64))
+        fsm_state_out = fsm_state_in.make_next_state(m, divisor)
+
+        with m.If(self.canceled):
+            with m.If(self.p.valid_i):
+                ...
+            with m.Else():
+                ...
+        with m.Else():
+            with m.If(self.p.valid_i):
+                ...
+            with m.Else():
+                ...
 
         return m
 
index 2ea291e996146e9992577f4307b724a92724d1e1..80f5a94b22e83927050060891279ffca495b0168 100644 (file)
@@ -6,18 +6,26 @@ from soc.fu.div.output_stage import DivOutputStage
 from soc.fu.div.setup_stage import DivSetupStage
 from soc.fu.div.core_stages import (DivCoreSetupStage, DivCoreCalculateStage,
                                     DivCoreFinalStage)
+from soc.fu.div.pipe_data import DivPipeKindConfigCombPipe
 
 
 class DivStagesStart(PipeModBaseChain):
     def get_chain(self):
         alu_input = DivMulInputStage(self.pspec)
         div_setup = DivSetupStage(self.pspec)
-        core_setup = DivCoreSetupStage(self.pspec)
-        return [alu_input, div_setup, core_setup]
+        if isinstance(self.pspec.div_pipe_kind.config,
+                      DivPipeKindConfigCombPipe):
+            core_setup = [DivCoreSetupStage(self.pspec)]
+        else:
+            core_setup = ()
+        return [alu_input, div_setup, *core_setup]
 
 
 class DivStagesMiddle(PipeModBaseChain):
     def __init__(self, pspec, stage_start_index, stage_end_index):
+        assert isinstance(pspec.div_pipe_kind.config,
+                          DivPipeKindConfigCombPipe),\
+            "DivStagesMiddle must be used with a DivPipeKindConfigCombPipe"
         self.stage_start_index = stage_start_index
         self.stage_end_index = stage_end_index
         super().__init__(pspec)
@@ -31,11 +39,15 @@ class DivStagesMiddle(PipeModBaseChain):
 
 class DivStagesEnd(PipeModBaseChain):
     def get_chain(self):
-        core_final = DivCoreFinalStage(self.pspec)
+        if isinstance(self.pspec.div_pipe_kind.config,
+                      DivPipeKindConfigCombPipe):
+            core_final = [DivCoreFinalStage(self.pspec)]
+        else:
+            core_final = ()
         div_out = DivOutputStage(self.pspec)
         alu_out = DivMulOutputStage(self.pspec)
         self.div_out = div_out  # debugging - bug #425
-        return [core_final, div_out, alu_out]
+        return [*core_final, div_out, alu_out]
 
 
 class DivBasePipe(ControlBase):
@@ -43,11 +55,16 @@ class DivBasePipe(ControlBase):
         ControlBase.__init__(self)
         self.pspec = pspec
         self.pipe_start = DivStagesStart(pspec)
-        compute_steps = pspec.core_config.n_stages
         self.pipe_middles = []
-        for start in range(0, compute_steps, compute_steps_per_stage):
-            end = min(start + compute_steps_per_stage, compute_steps)
-            self.pipe_middles.append(DivStagesMiddle(pspec, start, end))
+        if isinstance(self.pspec.div_pipe_kind.config,
+                      DivPipeKindConfigCombPipe):
+            compute_steps = pspec.core_config.n_stages
+            for start in range(0, compute_steps, compute_steps_per_stage):
+                end = min(start + compute_steps_per_stage, compute_steps)
+                self.pipe_middles.append(DivStagesMiddle(pspec, start, end))
+        else:
+            self.pipe_middles.append(
+                self.pspec.div_pipe_kind.config.core_stage_class(pspec))
         self.pipe_end = DivStagesEnd(pspec)
         self._eqs = self.connect([self.pipe_start,
                                   *self.pipe_middles,
diff --git a/src/soc/fu/div/test/test_fsm.py b/src/soc/fu/div/test/test_fsm.py
new file mode 100644 (file)
index 0000000..75f606f
--- /dev/null
@@ -0,0 +1,288 @@
+import unittest
+from soc.fu.div.fsm import DivState, DivStateInit, DivStateNext
+from nmigen import Elaboratable, Module, Signal, unsigned
+from nmigen.cli import rtlil
+from nmigen.sim.pysim import Simulator, Delay, Tick
+
+
+class CheckEvent(Elaboratable):
+    """helper to add indication to vcd when signals are checked"""
+
+    def __init__(self):
+        self.event = Signal()
+
+    def trigger(self):
+        yield self.event.eq(~self.event)
+
+    def elaborate(self, platform):
+        m = Module()
+        # use event somehow so nmigen simulation knows about it
+        m.d.comb += Signal().eq(self.event)
+        return m
+
+
+class DivStateCombTest(Elaboratable):
+    """Test stringing a bunch of copies of the FSM state-function together"""
+
+    def __init__(self, quotient_width):
+        self.check_event = CheckEvent()
+        self.quotient_width = quotient_width
+        self.dividend = Signal(unsigned(quotient_width * 2))
+        self.divisor = Signal(unsigned(quotient_width))
+        self.quotient = Signal(unsigned(quotient_width))
+        self.remainder = Signal(unsigned(quotient_width))
+        self.expected_quotient = Signal(unsigned(quotient_width))
+        self.expected_remainder = Signal(unsigned(quotient_width))
+        self.expected_valid = Signal()
+        self.states = []
+        for i in range(quotient_width + 1):
+            state = DivState(quotient_width=quotient_width, name=f"state{i}")
+            self.states.append(state)
+        self.init = DivStateInit(quotient_width)
+        self.nexts = []
+        for i in range(quotient_width):
+            next = DivStateNext(quotient_width)
+            self.nexts.append(next)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.submodules.check_event = self.check_event
+        m.submodules.init = self.init
+        m.d.comb += self.init.dividend.eq(self.dividend)
+        m.d.comb += self.states[0].eq(self.init.o)
+        last_state = self.states[0]
+        for i in range(self.quotient_width):
+            setattr(m.submodules, f"next{i}", self.nexts[i])
+            m.d.comb += self.nexts[i].divisor.eq(self.divisor)
+            m.d.comb += self.nexts[i].i.eq(last_state)
+            last_state = self.states[i + 1]
+            m.d.comb += last_state.eq(self.nexts[i].o)
+        m.d.comb += self.quotient.eq(last_state.quotient)
+        m.d.comb += self.remainder.eq(last_state.remainder)
+        m.d.comb += self.expected_valid.eq(
+            (self.dividend < (self.divisor << self.quotient_width))
+            & (self.divisor != 0))
+        with m.If(self.expected_valid):
+            m.d.comb += self.expected_quotient.eq(
+                self.dividend // self.divisor)
+            m.d.comb += self.expected_remainder.eq(
+                self.dividend % self.divisor)
+        return m
+
+
+class DivStateFSMTest(Elaboratable):
+    def __init__(self, quotient_width):
+        self.check_done_event = CheckEvent()
+        self.check_event = CheckEvent()
+        self.quotient_width = quotient_width
+        self.dividend = Signal(unsigned(quotient_width * 2))
+        self.divisor = Signal(unsigned(quotient_width))
+        self.quotient = Signal(unsigned(quotient_width))
+        self.remainder = Signal(unsigned(quotient_width))
+        self.expected_quotient = Signal(unsigned(quotient_width))
+        self.expected_remainder = Signal(unsigned(quotient_width))
+        self.expected_valid = Signal()
+        self.state = DivState(quotient_width=quotient_width,
+                              name="state")
+        self.next_state = DivState(quotient_width=quotient_width,
+                                   name="next_state")
+        self.init = DivStateInit(quotient_width)
+        self.next = DivStateNext(quotient_width)
+        self.state_done = Signal()
+        self.next_state_done = Signal()
+        self.clear = Signal(reset=1)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.submodules.check_event = self.check_event
+        m.submodules.check_done_event = self.check_done_event
+        m.submodules.init = self.init
+        m.submodules.next = self.next
+        m.d.comb += self.init.dividend.eq(self.dividend)
+        m.d.comb += self.next.divisor.eq(self.divisor)
+        m.d.comb += self.quotient.eq(self.state.quotient)
+        m.d.comb += self.remainder.eq(self.state.remainder)
+        m.d.comb += self.next.i.eq(self.state)
+        m.d.comb += self.state_done.eq(self.state.done)
+        m.d.comb += self.next_state_done.eq(self.next_state.done)
+
+        with m.If(self.state.done | self.clear):
+            m.d.comb += self.next_state.eq(self.init.o)
+        with m.Else():
+            m.d.comb += self.next_state.eq(self.next.o)
+
+        m.d.sync += self.state.eq(self.next_state)
+
+        m.d.comb += self.expected_valid.eq(
+            (self.dividend < (self.divisor << self.quotient_width))
+            & (self.divisor != 0))
+        with m.If(self.expected_valid):
+            m.d.comb += self.expected_quotient.eq(
+                self.dividend // self.divisor)
+            m.d.comb += self.expected_remainder.eq(
+                self.dividend % self.divisor)
+        return m
+
+
+def get_cases(quotient_width):
+    test_cases = []
+    mask = ~(~0 << quotient_width)
+    for i in range(-3, 4):
+        test_cases.append(i & mask)
+    for i in [-1, 0, 1]:
+        test_cases.append((i + (mask >> 1)) & mask)
+    test_cases.sort()
+    return test_cases
+
+
+class TestDivState(unittest.TestCase):
+    def test_div_state_comb(self, quotient_width=8):
+        test_cases = get_cases(quotient_width)
+        mask = ~(~0 << quotient_width)
+        dut = DivStateCombTest(quotient_width)
+        vl = rtlil.convert(dut,
+                           ports=[dut.dividend,
+                                  dut.divisor,
+                                  dut.quotient,
+                                  dut.remainder])
+        with open("div_fsm_comb_pipeline.il", "w") as f:
+            f.write(vl)
+        dut = DivStateCombTest(quotient_width)
+
+        def check(dividend, divisor):
+            with self.subTest(dividend=f"{dividend:#x}",
+                              divisor=f"{divisor:#x}"):
+                yield from dut.check_event.trigger()
+                for i in range(quotient_width + 1):
+                    # done must be correct and eventually true
+                    # even if a div-by-zero or overflow occurred
+                    done = yield dut.states[i].done
+                    self.assertEqual(done, i == quotient_width)
+                if divisor != 0:
+                    quotient = dividend // divisor
+                    remainder = dividend % divisor
+                    if quotient <= mask:
+                        with self.subTest(quotient=f"{quotient:#x}",
+                                          remainder=f"{remainder:#x}"):
+                            self.assertTrue((yield dut.expected_valid))
+                            self.assertEqual((yield dut.expected_quotient), quotient)
+                            self.assertEqual((yield dut.expected_remainder), remainder)
+                            self.assertEqual((yield dut.quotient), quotient)
+                            self.assertEqual((yield dut.remainder), remainder)
+                    else:
+                        self.assertFalse((yield dut.expected_valid))
+                else:
+                    self.assertFalse((yield dut.expected_valid))
+
+        def process(gen):
+            for dividend_high in test_cases:
+                for dividend_low in test_cases:
+                    dividend = dividend_low + \
+                        (dividend_high << quotient_width)
+                    for divisor in test_cases:
+                        if gen:
+                            yield Delay(0.5e-6)
+                            yield dut.dividend.eq(dividend)
+                            yield dut.divisor.eq(divisor)
+                            yield Delay(0.5e-6)
+                        else:
+                            yield Delay(1e-6)
+                            yield from check(dividend, divisor)
+
+        def gen_process():
+            yield from process(gen=True)
+
+        def check_process():
+            yield from process(gen=False)
+
+        sim = Simulator(dut)
+        with sim.write_vcd(vcd_file="div_fsm_comb_pipeline.vcd",
+                           gtkw_file="div_fsm_comb_pipeline.gtkw"):
+
+            sim.add_process(gen_process)
+            sim.add_process(check_process)
+            sim.run()
+
+    def test_div_state_fsm(self, quotient_width=8):
+        # TODO(programmerjake): fix test: for some reason
+        # the check process is delayed to the second division
+        # before it tries to do the first check
+        test_cases = get_cases(quotient_width)
+        mask = ~(~0 << quotient_width)
+        dut = DivStateFSMTest(quotient_width)
+        vl = rtlil.convert(dut,
+                           ports=[dut.dividend,
+                                  dut.divisor,
+                                  dut.quotient,
+                                  dut.remainder])
+        with open("div_fsm.il", "w") as f:
+            f.write(vl)
+
+        def check(dividend, divisor):
+            with self.subTest(dividend=f"{dividend:#x}",
+                              divisor=f"{divisor:#x}"):
+                for i in range(quotient_width + 1):
+                    yield Tick()
+                    yield Delay(0.1e-6)
+                    yield from dut.check_done_event.trigger()
+                    # done must be correct and eventually true
+                    # even if a div-by-zero or overflow occurred
+                    done = yield dut.state.done
+                    self.assertEqual(done, i == quotient_width)
+                yield from dut.check_event.trigger()
+                if divisor != 0:
+                    quotient = dividend // divisor
+                    remainder = dividend % divisor
+                    if quotient <= mask:
+                        with self.subTest(quotient=f"{quotient:#x}",
+                                          remainder=f"{remainder:#x}"):
+                            self.assertTrue((yield dut.expected_valid))
+                            self.assertEqual((yield dut.expected_quotient), quotient)
+                            self.assertEqual((yield dut.expected_remainder), remainder)
+                            self.assertEqual((yield dut.quotient), quotient)
+                            self.assertEqual((yield dut.remainder), remainder)
+                    else:
+                        self.assertFalse((yield dut.expected_valid))
+                else:
+                    self.assertFalse((yield dut.expected_valid))
+
+        def process(gen):
+            if gen:
+                yield dut.clear.eq(1)
+            else:
+                yield from dut.check_event.trigger()
+                yield from dut.check_done_event.trigger()
+            yield Tick()
+            for dividend_high in test_cases:
+                for dividend_low in test_cases:
+                    dividend = dividend_low + \
+                        (dividend_high << quotient_width)
+                    for divisor in test_cases:
+                        if gen:
+                            yield Delay(0.2e-6)
+                            yield dut.clear.eq(0)
+                            yield dut.dividend.eq(dividend)
+                            yield dut.divisor.eq(divisor)
+                            for _ in range(quotient_width):
+                                yield Tick()
+                        else:
+                            yield from check(dividend, divisor)
+
+        def gen_process():
+            yield from process(gen=True)
+
+        def check_process():
+            yield from process(gen=False)
+
+        sim = Simulator(dut)
+        with sim.write_vcd(vcd_file="div_fsm.vcd",
+                           gtkw_file="div_fsm.gtkw"):
+
+            sim.add_clock(1e-6)
+            sim.add_process(gen_process)
+            sim.add_process(check_process)
+            sim.run()
+
+
+if __name__ == "__main__":
+    unittest.main()