implement CLDivRemFSMStage
[nmigen-gf.git] / src / nmigen_gf / hdl / cldivrem.py
index f6ca4e83fe85887aae425ad4b523f9801c41500b..bff3676a2a069eeb6db3a6d8d6815bb71c659831 100644 (file)
@@ -9,9 +9,11 @@
 https://bugs.libre-soc.org/show_bug.cgi?id=784
 """
 
+from dataclasses import dataclass, field, fields
 from nmigen.hdl.ir import Elaboratable
-from nmigen.hdl.ast import Signal
+from nmigen.hdl.ast import Signal, Value
 from nmigen.hdl.dsl import Module
+from nmutil.singlepipe import ControlBase
 
 
 def equal_leading_zero_count_reference(a, b, width):
@@ -104,4 +106,241 @@ class EqualLeadingZeroCount(Elaboratable):
 
         return m
 
-# TODO: add CLDivRem
+
+@dataclass(frozen=True, unsafe_hash=True)
+class CLDivRemShape:
+    """Bit-widths describing a carry-less div/rem operation.
+
+    width: int
+        base bit-width; must be positive.
+    n_width: int
+        second bit-width; must be at least `width`.
+    """
+    width: int
+    n_width: int
+
+    def __post_init__(self):
+        # both widths must be positive, and n_width must not be narrower
+        assert 0 < self.width <= self.n_width
+
+    @property
+    def done_step(self):
+        """The step number at which iteration is finished (`width`)."""
+        return self.width
+
+    @property
+    def step_range(self):
+        """The range of step numbers, `0` through `done_step` inclusive."""
+        last_step = self.done_step
+        return range(last_step + 1)
+
+
+@dataclass(frozen=True, eq=False)
+class CLDivRemState:
+    """The signals making up the state of a carry-less div/rem iteration.
+
+    `shape` and `name` are configuration; `d`, `r`, `q` and `step` are
+    Signals created by `__init__`, with names prefixed by `name`.
+    """
+    shape: CLDivRemShape
+    name: str
+    # denominator: loaded shifted left by `shape.width`, then shifted right
+    # by one bit on every step
+    d: Signal = field(init=False)
+    # running remainder: loaded with the numerator, then conditionally
+    # xor-ed with `d` on each step
+    r: Signal = field(init=False)
+    # quotient: one bit is shifted in at the bottom on every step
+    q: Signal = field(init=False)
+    # number of steps taken so far; done once it reaches `shape.done_step`
+    step: Signal = field(init=False)
+
+    def __init__(self, shape, *, name=None, src_loc_at=0):
+        """Create the state signals.
+
+        shape: CLDivRemShape
+            the shape that sets the signal widths.
+        name: str | None
+            prefix for the signal names. When None, the name is taken from
+            a temporary Signal created with `src_loc_at=1 + src_loc_at`
+            (presumably nMigen's usual name inference from the caller's
+            assignment target -- TODO confirm).
+        src_loc_at: int
+            number of additional stack frames to skip for name inference.
+        """
+        assert isinstance(shape, CLDivRemShape)
+        if name is None:
+            name = Signal(src_loc_at=1 + src_loc_at).name
+        assert isinstance(name, str)
+        d = Signal(2 * shape.width, name=f"{name}_d")
+        r = Signal(shape.n_width, name=f"{name}_r")
+        q = Signal(shape.width, name=f"{name}_q")
+        # `step` only ever holds 0..done_step, so size it from the range
+        # rather than making it a full `width` bits wide.
+        step = Signal(shape.step_range, name=f"{name}_step")
+        # the dataclass is frozen, so bypass its __setattr__
+        object.__setattr__(self, "shape", shape)
+        object.__setattr__(self, "name", name)
+        object.__setattr__(self, "d", d)
+        object.__setattr__(self, "r", r)
+        object.__setattr__(self, "q", q)
+        object.__setattr__(self, "step", step)
+
+    def eq(self, rhs):
+        """Yield assignments of every signal field of `rhs` to the matching
+        field of `self` (`shape` and `name` are skipped)."""
+        assert isinstance(rhs, CLDivRemState)
+        for f in fields(CLDivRemState):
+            if f.name in ("shape", "name"):
+                continue
+            l = getattr(self, f.name)
+            r = getattr(rhs, f.name)
+            yield l.eq(r)
+
+    @staticmethod
+    def like(other, *, name=None, src_loc_at=0):
+        """Create a new CLDivRemState with the same shape as `other`."""
+        assert isinstance(other, CLDivRemState)
+        return CLDivRemState(other.shape, name=name, src_loc_at=1 + src_loc_at)
+
+    @property
+    def done(self):
+        """Expression that is true when this state has finished iterating."""
+        return self.will_be_done_after(steps=0)
+
+    def will_be_done_after(self, steps):
+        """ Returns an expression that is true if this state will be done
+            after another `steps` passes through `set_to_next`.
+
+            `steps` must be a non-negative Python int."""
+        assert isinstance(steps, int) and steps >= 0
+        # clamp at 0 so that large `steps` gives an always-true comparison
+        return self.step >= max(0, self.shape.done_step - steps)
+
+    def set_to_initial(self, m, n, d):
+        """Combinatorially drive this state with the starting values for
+        numerator `n` and denominator `d`."""
+        assert isinstance(m, Module)
+        m.d.comb += [
+            self.d.eq(Value.cast(d) << self.shape.width),
+            self.r.eq(n),
+            self.q.eq(0),
+            self.step.eq(0),
+        ]
+
+    def set_to_next(self, m, state_in):
+        """Combinatorially drive this state with the result of advancing
+        `state_in` by one step.
+
+        If `state_in` is already done, this state is a copy of it.
+        `state_in` must be a different object with an equal shape.
+        """
+        assert isinstance(m, Module)
+        assert isinstance(state_in, CLDivRemState)
+        assert state_in.shape == self.shape
+        assert self is not state_in, "a.set_to_next(m, a) is not allowed"
+
+        equal_leading_zero_count = EqualLeadingZeroCount(self.shape.n_width)
+        # can't name submodule since it would conflict if this function is
+        # called multiple times in a Module
+        m.submodules += equal_leading_zero_count
+
+        with m.If(state_in.done):
+            m.d.comb += self.eq(state_in)
+        with m.Else():
+            m.d.comb += [
+                self.step.eq(state_in.step + 1),
+                self.d.eq(state_in.d >> 1),
+                equal_leading_zero_count.a.eq(self.d),
+                equal_leading_zero_count.b.eq(state_in.r),
+            ]
+            # bits of the shifted `d` above the `n_width` bits that are
+            # compared against `r`; they must all be zero to subtract
+            d_top = self.d[self.shape.n_width:]
+            with m.If(equal_leading_zero_count.out & (d_top == 0)):
+                m.d.comb += [
+                    self.r.eq(state_in.r ^ self.d),
+                    self.q.eq((state_in.q << 1) | 1),
+                ]
+            with m.Else():
+                m.d.comb += [
+                    self.r.eq(state_in.r),
+                    self.q.eq(state_in.q << 1),
+                ]
+
+
+class CLDivRemInputData:
+    """Input data of a carry-less div/rem: signals `n` (`shape.n_width`
+    bits) and `d` (`shape.width` bits)."""
+
+    def __init__(self, shape):
+        assert isinstance(shape, CLDivRemShape)
+        self.shape = shape
+        # each Signal is assigned straight to its attribute, leaving
+        # nMigen to pick the signal name
+        self.n = Signal(shape.n_width)
+        self.d = Signal(shape.width)
+
+    def __iter__(self):
+        """ Iterate over the member signals, `n` then `d`. """
+        yield from (self.n, self.d)
+
+    def eq(self, rhs):
+        """ Return the list of assignments copying `rhs` into `self`. """
+        n_assign = self.n.eq(rhs.n)
+        d_assign = self.d.eq(rhs.d)
+        return [n_assign, d_assign]
+
+
+class CLDivRemOutputData:
+    """Output data of a carry-less div/rem: signals `q` and `r`, both
+    `shape.width` bits wide."""
+
+    def __init__(self, shape):
+        assert isinstance(shape, CLDivRemShape)
+        self.shape = shape
+        # each Signal is assigned straight to its attribute, leaving
+        # nMigen to pick the signal name
+        self.q = Signal(shape.width)
+        self.r = Signal(shape.width)
+
+    def __iter__(self):
+        """ Iterate over the member signals, `q` then `r`. """
+        yield from (self.q, self.r)
+
+    def eq(self, rhs):
+        """ Return the list of assignments copying `rhs` into `self`. """
+        q_assign = self.q.eq(rhs.q)
+        r_assign = self.r.eq(rhs.r)
+        return [q_assign, r_assign]
+
+
+class CLDivRemFSMStage(ControlBase):
+    """carry-less div/rem
+
+    Iterates over several clock cycles, chaining `steps_per_clock`
+    iteration steps combinatorially within each cycle.
+
+    Note that `self.p` and `self.n` are the input-side and output-side
+    ports used by `elaborate`; the numerator is `self.p.i_data.n`.
+
+    Attributes:
+    shape: CLDivRemShape
+        the shape
+    steps_per_clock: int
+        number of steps that should be taken per clock cycle
+    pspec:
+        stored as passed to `__init__`.
+    empty: Signal(reset=1)
+        true when no operation is held; drives `p.o_ready`.
+    saved_state: CLDivRemState
+        the iteration state registered from one clock cycle to the next.
+    p.i_valid: Signal()
+        input. true when the data inputs (`p.i_data.n` and `p.i_data.d`)
+        are valid.
+        data transfer in occurs when `p.i_valid & p.o_ready`.
+    p.o_ready: Signal()
+        output. true when this FSM is ready to accept input.
+        data transfer in occurs when `p.i_valid & p.o_ready`.
+    p.i_data.n: Signal(shape.n_width)
+        numerator in, the value must be small enough that `q` and `r` don't
+        overflow. having `n_width == width` is sufficient.
+    p.i_data.d: Signal(shape.width)
+        denominator in, must be non-zero.
+    n.o_data.q: Signal(shape.width)
+        quotient out.
+    n.o_data.r: Signal(shape.width)
+        remainder out.
+    n.o_valid: Signal()
+        output. true when the data outputs (`n.o_data.q` and `n.o_data.r`)
+        are valid (or are junk because the inputs were out of range).
+        data transfer out occurs when `n.o_valid & n.i_ready`.
+    n.i_ready: Signal()
+        input. true when the output can be read.
+        data transfer out occurs when `n.o_valid & n.i_ready`.
+    """
+
+    def __init__(self, pspec, shape, *, steps_per_clock=4):
+        assert isinstance(shape, CLDivRemShape)
+        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
+        self.shape = shape
+        self.steps_per_clock = steps_per_clock
+        # NOTE(review): ispec/ospec below read `self.shape`, not
+        # `self.pspec`; presumably the attributes have to exist before
+        # `super().__init__` runs -- confirm against ControlBase.
+        self.pspec = pspec  # store now: used in ispec and ospec
+        super().__init__(stage=self)
+        self.empty = Signal(reset=1)
+        self.saved_state = CLDivRemState(shape)
+
+    def ispec(self):
+        """Return a new input data object for this stage's shape."""
+        return CLDivRemInputData(self.shape)
+
+    def ospec(self):
+        """Return a new output data object for this stage's shape."""
+        return CLDivRemOutputData(self.shape)
+
+    def setup(self, m, i):
+        """Does nothing; all logic is created in `elaborate`."""
+        pass
+
+    def elaborate(self, platform):
+        """Build the FSM on top of the Module returned by
+        `ControlBase.elaborate`."""
+        m = super().elaborate(platform)
+        i_data: CLDivRemInputData = self.p.i_data
+        o_data: CLDivRemOutputData = self.n.o_data
+
+        # TODO: handle cancellation
+
+        # the output is valid once the saved state needs no more than this
+        # cycle's `steps_per_clock` steps to finish
+        state_will_be_done = self.saved_state.will_be_done_after(
+            self.steps_per_clock)
+        m.d.comb += self.n.o_valid.eq(~self.empty & state_will_be_done)
+        m.d.comb += self.p.o_ready.eq(self.empty)
+
+        # next_chain[0] is this cycle's starting state; each following
+        # entry is the previous one advanced by a single step
+        def make_nc(i):
+            return CLDivRemState(self.shape, name=f"next_chain_{i}")
+        next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)]
+        for i in range(self.steps_per_clock):
+            next_chain[i + 1].set_to_next(m, next_chain[i])
+        # the end of the chain is registered every clock cycle and also
+        # drives the data outputs directly
+        m.d.sync += self.saved_state.eq(next_chain[-1])
+        m.d.comb += o_data.q.eq(next_chain[-1].q)
+        m.d.comb += o_data.r.eq(next_chain[-1].r)
+
+        with m.If(self.empty):
+            # idle: start the chain from the input data, and become
+            # non-empty when that input is valid
+            next_chain[0].set_to_initial(m, n=i_data.n, d=i_data.d)
+            with m.If(self.p.i_valid):
+                m.d.sync += self.empty.eq(0)
+        with m.Else():
+            # busy: continue from the saved state, and become empty again
+            # once the output has been transferred
+            m.d.comb += next_chain[0].eq(self.saved_state)
+            with m.If(self.n.i_ready & self.n.o_valid):
+                m.d.sync += self.empty.eq(1)
+
+        return m
+
+    def __iter__(self):
+        """Yield everything from the `p` port, then from the `n` port."""
+        yield from self.p
+        yield from self.n
+
+    def ports(self):
+        """Return the items of `__iter__` as a list."""
+        return list(self)