# Commit summary (truncated in the page extract): "too much debug info going past, so add the test registers to the ..."
# File: src/soc/fu/div/fsm.py (repository: soc.git)
1 import enum
2 from nmigen import Elaboratable, Module, Signal, Shape, unsigned, Cat, Mux
3 from soc.fu.div.pipe_data import CoreInputData, CoreOutputData, DivPipeSpec
4 from nmutil.iocontrol import PrevControl, NextControl
5 from nmutil.singlepipe import ControlBase
6 from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation
7
8
class FSMDivCoreConfig:
    """ Constant configuration values for the FSM divider core. """

    # all three values are fixed class attributes, not per-instance settings
    n_stages = 1
    bit_width = 64
    fract_width = 64
13
14
class FSMDivCoreInputData:
    """ Input operands of the FSM divider core.

    Members:
    * dividend -- 128-bit signal
    * divisor_radicand -- 64-bit signal
    * operation -- signal created by DivPipeCoreOperation.create_signal()
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals.

        :param core_config: stored as-is on the instance; not otherwise
                            used by this class.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        self.dividend = Signal(128, reset_less=reset_less)
        self.divisor_radicand = Signal(64, reset_less=reset_less)
        self.operation = DivPipeCoreOperation.create_signal(
            reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield self.dividend
        yield self.divisor_radicand
        yield self.operation

    def eq(self, rhs):
        """ Assign member signals. """
        return [self.dividend.eq(rhs.dividend),
                self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),
                ]
35
36
class FSMDivCoreOutputData:
    """ Results of the FSM divider core.

    Members:
    * quotient_root -- 64-bit signal
    * remainder -- 192-bit (3 * 64) signal
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals.

        :param core_config: stored as-is on the instance; not otherwise
                            used by this class.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        self.quotient_root = Signal(64, reset_less=reset_less)
        self.remainder = Signal(3 * 64, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield self.quotient_root
        yield self.remainder

    def eq(self, rhs):
        """ Assign member signals. """
        return [self.quotient_root.eq(rhs.quotient_root),
                self.remainder.eq(rhs.remainder)]
53
54
class FSMDivCorePrevControl(PrevControl):
    """ PrevControl whose data_i payload is a CoreInputData. """
    data_i: CoreInputData

    def __init__(self, pspec):
        """ Store pspec and build data_i from it. """
        super().__init__()
        self.pspec = pspec
        self.data_i = CoreInputData(pspec)
62
63
class FSMDivCoreNextControl(NextControl):
    """ NextControl whose data_o payload is a CoreOutputData. """
    data_o: CoreOutputData

    def __init__(self, pspec):
        """ Store pspec and build data_o from it. """
        super().__init__()
        self.pspec = pspec
        self.data_o = CoreOutputData(pspec)
71
72
class DivStateNext(Elaboratable):
    """ Combinatorial logic producing the DivState after one more step.

    Ports:
    * i -- current DivState
    * divisor -- quotient_width-bit divisor
    * o -- resulting DivState; a plain copy of ``i`` once ``i.done`` is set
    """

    def __init__(self, quotient_width):
        self.quotient_width = quotient_width
        self.i = DivState(quotient_width=quotient_width, name="i")
        self.divisor = Signal(quotient_width)
        self.o = DivState(quotient_width=quotient_width, name="o")

    def elaborate(self, platform):
        m = Module()

        # trial subtraction of the divisor, aligned by shifting it left
        # by (quotient_width - 1) bits
        difference = Signal(self.i.quotient_width * 2)
        m.d.comb += difference.eq(self.i.dividend_quotient
                                  - (self.divisor
                                     << (self.quotient_width - 1)))

        # the new quotient bit is the inverse of the top bit of the
        # difference
        next_quotient_bit = Signal()
        m.d.comb += next_quotient_bit.eq(
            ~difference[self.quotient_width * 2 - 1])

        # keep the difference when the quotient bit is 1, otherwise keep
        # the unmodified input value
        value = Signal(self.i.quotient_width * 2)
        with m.If(next_quotient_bit):
            m.d.comb += value.eq(difference)
        with m.Else():
            m.d.comb += value.eq(self.i.dividend_quotient)

        with m.If(self.i.done):
            # nothing left to compute: pass the state through unchanged
            m.d.comb += self.o.eq(self.i)
        with m.Else():
            # Cat() puts the new quotient bit in the LSB with value above it
            m.d.comb += [
                self.o.q_bits_known.eq(self.i.q_bits_known + 1),
                self.o.dividend_quotient.eq(Cat(next_quotient_bit, value))]
        return m
102
103
class DivStateInit(Elaboratable):
    """ Combinatorial logic producing the initial DivState.

    Ports:
    * dividend -- (2 * quotient_width)-bit dividend
    * o -- DivState with q_bits_known = 0 and dividend_quotient = dividend
    """

    def __init__(self, quotient_width):
        self.quotient_width = quotient_width
        self.dividend = Signal(quotient_width * 2)
        self.o = DivState(quotient_width=quotient_width, name="o")

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.o.q_bits_known.eq(0)
        m.d.comb += self.o.dividend_quotient.eq(self.dividend)
        return m
115
116
class DivState:
    """ State of an in-progress division.

    Members:
    * q_bits_known -- counter able to hold 0..quotient_width
    * dividend_quotient -- unsigned, 2 * quotient_width bits wide; its
      lower and upper halves are exposed as ``quotient`` and ``remainder``
    """

    def __init__(self, quotient_width, name):
        """ Create the state signals.

        :param quotient_width: width of the quotient in bits.
        :param name: prefix used for the names of the created signals.
        """
        self.quotient_width = quotient_width
        self.q_bits_known = Signal(range(1 + quotient_width),
                                   name=name + "_q_bits_known")
        self.dividend_quotient = Signal(unsigned(2 * quotient_width),
                                        name=name + "_dividend_quotient")

    @property
    def done(self):
        """ true once q_bits_known equals quotient_width """
        return self.q_bits_known == self.quotient_width

    @property
    def quotient(self):
        """ get the quotient -- requires self.done is True """
        return self.dividend_quotient[0:self.quotient_width]

    @property
    def remainder(self):
        """ get the remainder -- requires self.done is True """
        return self.dividend_quotient[self.quotient_width:self.quotient_width*2]

    def eq(self, rhs):
        """ Assign member signals. """
        return [self.q_bits_known.eq(rhs.q_bits_known),
                self.dividend_quotient.eq(rhs.dividend_quotient)]
142
143
class FSMDivCoreStage(ControlBase):
    """ Divider core stage built on DivStateInit / DivStateNext.

    Handshake, as wired up in elaborate():
    * p.ready_o is driven by ``empty``: input is only accepted while no
      operation is held.
    * n.valid_o is asserted when the stage is not empty and the state
      produced by div_state_next reports done.
    * ``empty`` is cleared when p.valid_i is seen while empty, and set
      again when n.ready_i and n.valid_o are both high.
    """

    def __init__(self, pspec: DivPipeSpec):
        super().__init__()
        self.pspec = pspec
        # override p and n
        self.p = FSMDivCorePrevControl(pspec)
        self.n = FSMDivCoreNextControl(pspec)
        self.saved_input_data = CoreInputData(pspec)
        self.empty = Signal(reset=1)
        self.saved_state = DivState(64, name="saved_state")
        self.div_state_next = DivStateNext(64)
        self.div_state_init = DivStateInit(64)
        self.divisor = Signal(unsigned(64))

    def elaborate(self, platform):
        m = super().elaborate(platform)
        m.submodules.div_state_next = self.div_state_next
        m.submodules.div_state_init = self.div_state_init
        data_i = self.p.data_i
        core_i: FSMDivCoreInputData = data_i.core
        data_o = self.n.data_o
        core_o: FSMDivCoreOutputData = data_o.core
        core_saved_i: FSMDivCoreInputData = self.saved_input_data.core

        # TODO: handle cancellation

        m.d.comb += self.div_state_init.dividend.eq(core_i.dividend)

        # outputs: non-core fields come from the saved input, results come
        # straight from the combinatorial next-state logic
        m.d.comb += data_o.eq_without_core(self.saved_input_data)
        m.d.comb += core_o.quotient_root.eq(self.div_state_next.o.quotient)
        m.d.comb += core_o.remainder.eq(self.div_state_next.o.remainder)
        m.d.comb += self.n.valid_o.eq(~self.empty & self.div_state_next.o.done)
        m.d.comb += self.p.ready_o.eq(self.empty)
        # the next state is registered on every clock, empty or not
        m.d.sync += self.saved_state.eq(self.div_state_next.o)

        with m.If(self.empty):
            # idle: feed the step logic from the initial state and the
            # incoming divisor
            m.d.comb += self.div_state_next.i.eq(self.div_state_init.o)
            m.d.comb += self.div_state_next.divisor.eq(core_i.divisor_radicand)
            with m.If(self.p.valid_i):
                m.d.sync += self.empty.eq(0)
                m.d.sync += self.saved_input_data.eq(data_i)
        with m.Else():
            # busy: continue from the registered state, using the divisor
            # captured with the saved input data
            m.d.comb += [
                self.div_state_next.i.eq(self.saved_state),
                self.div_state_next.divisor.eq(core_saved_i.divisor_radicand)]
            with m.If(self.n.ready_i & self.n.valid_o):
                m.d.sync += self.empty.eq(1)

        return m

    def __iter__(self):
        """ Yield the members of p followed by the members of n. """
        yield from self.p
        yield from self.n

    def ports(self):
        """ Return everything yielded by __iter__ as a list. """
        return list(self)