change all uses of dataclass to plain_data
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Array)
22 from nmigen.lib.coding import PriorityEncoder
23 import enum
24
25
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Create a ``DivPipeCoreConfig`` instance.

        As a side effect, prints the configuration together with the
        resulting number of calculate stages.
        """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        print(f"{self}: n_stages={self.n_stages}")

    def __repr__(self):
        """ Get repr, e.g. ``DivPipeCoreConfig(8, 4, 2)``. """
        args = f"{self.bit_width}, {self.fract_width}, {self.log2_radix}"
        return f"DivPipeCoreConfig({args})"

    @property
    def n_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances needed.

        This is ``bit_width / log2_radix`` rounded up, since the last stage
        may handle fewer than ``log2_radix`` bits.
        """
        per_stage = self.log2_radix
        return (self.bit_width + per_stage - 1) // per_stage
52
53
class DivPipeCoreOperation(enum.Enum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Convert to int. """
        return self.value

    @classmethod
    def _decode_value(cls, value):
        """ Render a raw signal value as text for trace/waveform output.

        The signal built by ``create_signal`` is wider than strictly needed
        and can hold values that are not members (e.g. ``3``).  Such values
        are rendered as a placeholder instead of letting ``cls(value)`` raise
        ``ValueError`` from inside the decoder.
        """
        try:
            return str(cls(value))
        except ValueError:
            return f"<invalid {cls.__name__}: {value}>"

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: forwarded to ``Signal`` plus one (presumably to
            skip this helper's own stack frame when recording the source
            location).
        :param kwargs: additional keyword arguments forwarded to ``Signal``.
        """
        return Signal(range(min(map(int, cls)), max(map(int, cls)) + 2),
                      src_loc_at=(src_loc_at + 1),
                      decoder=cls._decode_value,
                      **kwargs)


# Short alias used throughout this module.
DP = DivPipeCoreOperation
80
81
class DivPipeCoreInputData:
    """ Input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInputData`` instance.

        :param core_config: ``DivPipeCoreConfig`` used to size the signals.
        :param reset_less: passed through to every member signal.
        """
        self.core_config = core_config
        bit_width = core_config.bit_width
        fract_width = core_config.fract_width
        self.dividend = Signal(bit_width + fract_width, reset_less=reset_less)
        self.divisor_radicand = Signal(bit_width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        yield from (self.dividend, self.divisor_radicand, self.operation)

    def eq(self, rhs):
        """ Return statements assigning each member signal from ``rhs``. """
        pairs = ((self.dividend, rhs.dividend),
                 (self.divisor_radicand, rhs.divisor_radicand),
                 (self.operation, rhs.operation))
        return [dst.eq(src) for dst, src in pairs]
117
118
class DivPipeCoreInterstageData:
    """ Interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInterstageData`` instance.

        :param core_config: ``DivPipeCoreConfig`` used to size the signals.
        :param reset_less: passed through to every member signal.
        """
        self.core_config = core_config
        width = core_config.bit_width
        self.divisor_radicand = Signal(width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.root_times_radicand = Signal(width * 2, reset_less=reset_less)
        self.compare_lhs = Signal(width * 3, reset_less=reset_less)
        self.compare_rhs = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        yield from (self.divisor_radicand,
                    self.operation,
                    self.quotient_root,
                    self.root_times_radicand,
                    self.compare_lhs,
                    self.compare_rhs)

    def eq(self, rhs):
        """ Return statements assigning each member signal from ``rhs``. """
        pairs = ((self.divisor_radicand, rhs.divisor_radicand),
                 (self.operation, rhs.operation),
                 (self.quotient_root, rhs.quotient_root),
                 (self.root_times_radicand, rhs.root_times_radicand),
                 (self.compare_lhs, rhs.compare_lhs),
                 (self.compare_rhs, rhs.compare_rhs))
        return [dst.eq(src) for dst, src in pairs]
172
173
class DivPipeCoreOutputData:
    """ Output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreOutputData`` instance.

        :param core_config: ``DivPipeCoreConfig`` used to size the signals.
        :param reset_less: passed through to every member signal.
        """
        self.core_config = core_config
        width = core_config.bit_width
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.remainder = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Return statements assigning each member signal from ``rhs``. """
        pairs = ((self.quotient_root, rhs.quotient_root),
                 (self.remainder, rhs.remainder))
        return [dst.eq(src) for dst, src in pairs]
204
205
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Turns ``DivPipeCoreInputData`` into the first
    ``DivPipeCoreInterstageData``:
    * ``quotient_root``, ``root_times_radicand`` and ``compare_rhs`` start
      at zero;
    * ``compare_lhs`` is the left-hand side of the formula selected by
      ``operation``, aligned to ``fract_width * 3`` fractional bits.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Register as submodule of ``m`` and connect ``i`` to the input. """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return the output data; ``i`` is ignored (wired in ``setup``). """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        inp, out = self.i, self.o
        fw = self.core_config.fract_width

        # pass-through and zero-initialised accumulators
        comb += out.divisor_radicand.eq(inp.divisor_radicand)
        comb += out.quotient_root.eq(0)
        comb += out.root_times_radicand.eq(0)

        # left-hand side of each formula, shifted so every case ends up with
        # fract_width * 3 fractional bits (dividend has 2*fw, radicand fw,
        # and the constant 1 has none).
        lhs = Signal(self.core_config.bit_width * 3, reset_less=True)
        lhs_per_operation = (
            (DP.UDivRem, inp.dividend << fw),
            (DP.SqrtRem, inp.divisor_radicand << (fw * 2)),
            (DP.RSqrtRem, 1 << (fw * 3)),
        )
        with m.Switch(inp.operation):
            for operation, value in lhs_per_operation:
                with m.Case(int(operation)):
                    comb += lhs.eq(value)

        comb += out.compare_lhs.eq(lhs)
        comb += out.compare_rhs.eq(0)
        comb += out.operation.eq(inp.operation)

        return m
257
258
class Trial(Elaboratable):
    """ One candidate ("trial") evaluated by ``DivPipeCoreCalculateStage``.

    Computes ``trial_compare_rhs``: the value ``compare_rhs`` would take if
    ``trial_bits`` were placed into ``quotient_root`` at bit offset
    ``current_shift``, for the formula selected by ``operation``.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: ``DivPipeCoreConfig`` describing the pipeline.
        :param trial_bits: candidate value for the bits being decided;
            expected to fit in ``log2_radix`` bits.
        :param current_shift: bit offset of ``trial_bits`` within
            ``quotient_root``.
        :param log2_radix: number of ``quotient_root`` bits being decided.
        """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        bw = core_config.bit_width
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(bw * 3, reset_less=True)
        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        fw = self.core_config.fract_width
        shift = self.current_shift
        lr = self.log2_radix
        dr = self.divisor_radicand
        qr = self.quotient_root
        rr = self.root_times_radicand
        base_rhs = self.compare_rhs

        trial_bits_sig = Const(self.trial_bits, lr)
        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits, lr * 2)

        tblen = self.core_config.bit_width + lr
        tblen2 = self.core_config.bit_width + lr * 2
        # only consumed by the rsqrt case, but driven for every operation
        dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
        comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)

        with m.Switch(self.operation):
            with m.Case(int(DP.UDivRem)):
                # (q + t*2^s) * d  ==  q*d + (d*t)*2^s
                # extra fw shift aligns the product to compare_rhs scaling
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
                comb += self.trial_compare_rhs.eq(
                    base_rhs
                    + (dr_times_trial_bits << (fw + shift)))

            with m.Case(int(DP.SqrtRem)):
                # (q + t*2^s)^2  ==  q^2 + 2*q*t*2^s + t^2*2^(2s)
                qr_times_trial_bits = Signal((tblen + 1) * 2, reset_less=True)
                comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
                comb += self.trial_compare_rhs.eq(
                    base_rhs
                    + (qr_times_trial_bits << (fw + shift + 1))
                    + (trial_bits_sqrd_sig << (fw + shift * 2)))

            with m.Case(int(DP.RSqrtRem)):
                # (q + t*2^s)^2 * r  ==  q^2*r + 2*(q*r)*t*2^s + r*t^2*2^(2s)
                rr_times_trial_bits = Signal((tblen + 1) * 3, reset_less=True)
                comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
                comb += self.trial_compare_rhs.eq(
                    base_rhs
                    + (rr_times_trial_bits << (shift + 1))
                    + (dr_times_trial_bits_sqrd << (shift * 2)))

        return m
338
339
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage decides up to ``core_config.log2_radix`` bits of
    ``quotient_root``, working from the most-significant end: stage
    ``stage_index`` handles the bit range printed during ``elaborate``.
    The final stage may handle fewer bits when ``bit_width`` is not a
    multiple of ``log2_radix``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` describing the pipeline.
        :param stage_index: index of this stage; must be in
            ``range(core_config.n_stages)``.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Register as submodule of ``m`` and connect ``i`` to the input. """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        # copy invariant inputs to outputs (for next stage)
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.operation.eq(self.i.operation)
        comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # constants: this stage decides quotient_root bits
        # [current_shift, current_shift + log2_radix); log2_radix is clipped
        # so the final stage does not run past bit 0.
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              + f" of {self.core_config.n_stages} handling "
              + f"bits [{current_shift}, {current_shift+log2_radix})"
              + f" of {self.core_config.bit_width}")
        radix = 1 << log2_radix

        # trials within this radix range. carried out by Trial module,
        # results stored in pass_flags. pass_flags are unary priority.
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)

            comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            comb += t.quotient_root.eq(self.i.quotient_root)
            comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            comb += t.compare_rhs.eq(self.i.compare_rhs)
            comb += t.operation.eq(self.i.operation)

            # get the trial output
            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # make the trial comparison against the [invariant] lhs.
            # trial_compare_rhs grows with trial_bits (each Trial adds
            # non-negative terms that scale with trial_bits), so the pass
            # flags form a run of ones followed by a run of zeros.
            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
            pfl.append(pass_flag)

        # Cat all the pass flags list together (easier to handle, below)
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pfl))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # The encoder looks at ~pass_flags: its lowest set bit is the first
        # failing trial, so the last passing trial is one below it.  When no
        # trial fails (pe.n), the highest trial (radix - 1) is selected.
        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        comb += pe.i.eq(~pass_flags)
        with m.If(~pe.n):
            comb += next_bits.eq(pe.o-1)
        with m.Else():
            comb += next_bits.eq(radix-1)

        # get the highest passing rhs trial (indexed by next_bits)
        ta = Array(trial_compare_rhs_values)
        comb += self.o.compare_rhs.eq(ta[next_bits])

        # create outputs for next phase: insert the chosen bits into
        # quotient_root and keep root_times_radicand == quotient_root * dr
        qr = self.i.quotient_root | (next_bits << current_shift)
        rr = self.i.root_times_radicand + ((self.i.divisor_radicand * next_bits)
                                           << current_shift)
        comb += self.o.quotient_root.eq(qr)
        comb += self.o.root_times_radicand.eq(rr)

        return m
448
449
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Produces ``DivPipeCoreOutputData``: ``quotient_root`` is passed through
    and ``remainder`` is ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Register as submodule of ``m`` and connect ``i`` to the input. """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return the output data; ``i`` is ignored (wired in ``setup``). """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        src, dst = self.i, self.o
        m.d.comb += [
            dst.quotient_root.eq(src.quotient_root),
            dst.remainder.eq(src.compare_lhs - src.compare_rhs),
        ]
        return m
483
484 return m