write-ok is expected to stay valid *after* being set,
[soc.git] / src / soc / fu / div / fsm.py
1 import enum
2 from nmigen import Elaboratable, Module, Signal, Shape, unsigned, Cat, Mux
3 from soc.fu.div.pipe_data import CoreInputData, CoreOutputData, DivPipeSpec
4 from nmutil.iocontrol import PrevControl, NextControl
5 from nmutil.singlepipe import ControlBase
6 from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation
7
8
class FSMDivCoreConfig:
    """Fixed configuration values for the FSM-based divider core."""

    # stage count advertised by this config (the FSM core is a single stage)
    n_stages = 1
    # operand width in bits -- presumably mirrors DivPipeCoreConfig; confirm
    bit_width = 64
    # fractional width in bits -- presumably mirrors DivPipeCoreConfig; confirm
    fract_width = 64
13
14
class FSMDivCoreInputData:
    """Operand bundle for the FSM divider core.

    Holds a 128-bit ``dividend``, a 64-bit ``divisor_radicand`` and the
    ``operation`` signal created by `DivPipeCoreOperation.create_signal`.
    """

    def __init__(self, core_config, reset_less=True):
        self.core_config = core_config
        self.dividend = Signal(128, reset_less=reset_less)
        self.divisor_radicand = Signal(64, reset_less=reset_less)
        self.operation = DivPipeCoreOperation.create_signal(
            reset_less=reset_less)

    def __iter__(self):
        """Yield the member signals in declaration order."""
        yield from (self.dividend, self.divisor_radicand, self.operation)

    def eq(self, rhs):
        """Return a list of assignments copying each member from ``rhs``."""
        members = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in members]
35
36
class FSMDivCoreOutputData:
    """Result bundle for the FSM divider core.

    Holds a 64-bit ``quotient_root`` and a ``3 * 64``-bit ``remainder``.
    """

    def __init__(self, core_config, reset_less=True):
        self.core_config = core_config
        self.quotient_root = Signal(64, reset_less=reset_less)
        self.remainder = Signal(3 * 64, reset_less=reset_less)

    def __iter__(self):
        """Yield the member signals in declaration order."""
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """Return a list of assignments copying each member from ``rhs``."""
        quotient_assign = self.quotient_root.eq(rhs.quotient_root)
        remainder_assign = self.remainder.eq(rhs.remainder)
        return [quotient_assign, remainder_assign]
53
54
class DivStateNext(Elaboratable):
    """Combinational logic for one step of restoring long division.

    Reads the current `DivState` on ``i`` plus ``divisor`` and drives the
    state after one more quotient bit has been determined on ``o``.  When
    ``i`` is already done, ``o`` is a straight copy of ``i``.
    """

    def __init__(self, quotient_width):
        self.quotient_width = quotient_width
        self.i = DivState(quotient_width=quotient_width, name="i")
        self.divisor = Signal(quotient_width)
        self.o = DivState(quotient_width=quotient_width, name="o")

    def elaborate(self, platform):
        m = Module()
        width = self.quotient_width
        double_width = 2 * width
        current = self.i.dividend_quotient

        # trial subtraction of the divisor aligned just below the top bit
        # NOTE(review): relies on the caller guaranteeing that the dividend
        # is less than (divisor << quotient_width) -- confirm upstream
        shifted_divisor = self.divisor << (width - 1)
        difference = Signal(double_width)
        m.d.comb += difference.eq(current - shifted_divisor)

        # a clear top bit means the subtraction did not wrap negative,
        # so the divisor "fits" and this quotient bit is 1
        next_quotient_bit = Signal()
        m.d.comb += next_quotient_bit.eq(~difference[double_width - 1])

        # restoring step: keep the difference only when it succeeded
        value = Signal(double_width)
        with m.If(next_quotient_bit):
            m.d.comb += value.eq(difference)
        with m.Else():
            m.d.comb += value.eq(current)

        with m.If(self.i.done):
            m.d.comb += self.o.eq(self.i)
        with m.Else():
            m.d.comb += [
                self.o.q_bits_known.eq(self.i.q_bits_known + 1),
                # shift left by one (top bit is truncated by the assignment)
                # and insert the new quotient bit at the LSB
                self.o.dividend_quotient.eq(Cat(next_quotient_bit, value)),
            ]
        return m
84
85
class DivStateInit(Elaboratable):
    """Combinational logic producing the starting `DivState`.

    Drives ``o`` with zero known quotient bits and the full
    ``2 * quotient_width``-bit ``dividend`` loaded into the register.
    """

    def __init__(self, quotient_width):
        self.quotient_width = quotient_width
        self.dividend = Signal(quotient_width * 2)
        self.o = DivState(quotient_width=quotient_width, name="o")

    def elaborate(self, platform):
        m = Module()
        start = self.o
        m.d.comb += [
            start.q_bits_known.eq(0),
            start.dividend_quotient.eq(self.dividend),
        ]
        return m
97
98
class DivState:
    """State of the iterative restoring divider.

    ``dividend_quotient`` is a ``2 * quotient_width``-bit register.  The upper
    half is exposed as `remainder` and the lower half as `quotient`; new
    quotient bits are shifted in at the LSB by `DivStateNext`.
    ``q_bits_known`` counts how many quotient bits have been produced.
    """

    def __init__(self, quotient_width, name):
        self.quotient_width = quotient_width
        self.q_bits_known = Signal(range(1 + quotient_width),
                                   name=name + "_q_bits_known")
        self.dividend_quotient = Signal(unsigned(2 * quotient_width),
                                        name=name + "_dividend_quotient")

    @property
    def done(self):
        """Expression that is true once all quotient bits are known."""
        return self.will_be_done_after(steps=0)

    def will_be_done_after(self, steps):
        """Return an expression that is 1 if this state will be done after
        another ``steps`` passes through `DivStateNext`.

        :param steps: non-negative Python ``int``, fixed at elaboration time.
        :raises TypeError: if ``steps`` is not an ``int``.
        :raises ValueError: if ``steps`` is negative.
        """
        # explicit checks instead of `assert`, which is stripped under -O
        if not isinstance(steps, int):
            raise TypeError("steps must be an integer")
        if steps < 0:
            raise ValueError("steps must be non-negative")
        return self.q_bits_known >= max(0, self.quotient_width - steps)

    @property
    def quotient(self):
        """Lower half of the register -- only valid once `done` is true."""
        return self.dividend_quotient[0:self.quotient_width]

    @property
    def remainder(self):
        """Upper half of the register -- only valid once `done` is true."""
        return self.dividend_quotient[self.quotient_width:
                                      self.quotient_width * 2]

    def eq(self, rhs):
        """Return assignments copying both state signals from ``rhs``."""
        return [self.q_bits_known.eq(rhs.q_bits_known),
                self.dividend_quotient.eq(rhs.dividend_quotient)]
131
132
class FSMDivCoreStage(ControlBase):
    """Multi-cycle divider core driven by a simple FSM.

    While ``empty`` is set the stage reports ``p.o_ready`` and, on
    ``p.i_valid``, latches the input.  Every clock, ``saved_state`` is
    replaced by one `DivStateNext` step; ``n.o_valid`` is raised once that
    next step completes the quotient, and ``empty`` is set again after the
    result is accepted via ``n.i_ready``.
    """

    def __init__(self, pspec):
        super().__init__()
        self.pspec = pspec
        self.p.i_data = CoreInputData(pspec)
        self.n.o_data = CoreOutputData(pspec)
        # copy of the accepted input, held for the whole operation
        self.saved_input_data = CoreInputData(pspec)
        # set while no operation is in flight
        self.empty = Signal(reset=1)
        self.saved_state = DivState(64, name="saved_state")
        self.div_state_next = DivStateNext(64)
        self.div_state_init = DivStateInit(64)
        self.divisor = Signal(unsigned(64))

    def elaborate(self, platform):
        m = super().elaborate(platform)
        m.submodules.div_state_next = self.div_state_next
        m.submodules.div_state_init = self.div_state_init

        incoming = self.p.i_data
        outgoing = self.n.o_data
        in_core = incoming.core
        out_core = outgoing.core
        held_core = self.saved_input_data.core
        step = self.div_state_next

        # TODO: handle cancellation

        m.d.comb += self.div_state_init.dividend.eq(in_core.dividend)

        # outputs are always driven from the held input and the next state;
        # they are only meaningful while n.o_valid is asserted
        m.d.comb += outgoing.eq_without_core(self.saved_input_data)
        m.d.comb += out_core.quotient_root.eq(step.o.quotient)
        # `DivPipeCoreOutputData.remainder` has 64 * 3 fractional bits while
        # `DivPipeCoreInputData.dividend` has 64 * 2, so shift by the gap
        remainder_fract_width = 64 * 3
        dividend_fract_width = 64 * 2
        m.d.comb += out_core.remainder.eq(
            step.o.remainder << (remainder_fract_width - dividend_fract_width))

        m.d.comb += self.n.o_valid.eq(
            ~self.empty & self.saved_state.will_be_done_after(1))
        m.d.comb += self.p.o_ready.eq(self.empty)
        # advance by one quotient bit every clock
        m.d.sync += self.saved_state.eq(step.o)

        with m.If(self.empty):
            # idle: iterate from the freshly initialised state
            m.d.comb += [
                step.i.eq(self.div_state_init.o),
                step.divisor.eq(in_core.divisor_radicand),
            ]
            with m.If(self.p.i_valid):
                m.d.sync += self.empty.eq(0)
                m.d.sync += self.saved_input_data.eq(incoming)
        with m.Else():
            # busy: keep iterating from the registered state
            m.d.comb += [
                step.i.eq(self.saved_state),
                step.divisor.eq(held_core.divisor_radicand),
            ]
            with m.If(self.n.i_ready & self.n.o_valid):
                m.d.sync += self.empty.eq(1)

        return m

    def __iter__(self):
        for port_group in (self.p, self.n):
            yield from port_group

    def ports(self):
        return [*self]