e7208ee55210bf1b8ba657a70bdc75524f919548
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Repl)
22 from nmigen.lib.coding import PriorityEncoder
23 from nmutil.util import treereduce
24 import enum
25 import operator
26
27
class DivPipeCoreOperation(enum.Enum):
    """ Operation for ``DivPipeCore``.

    :attribute SqrtRem: square-root/remainder.
    :attribute UDivRem: unsigned divide/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Convert to int. """
        return self.value

    @classmethod
    def _decode(cls, value):
        """ Get the display string for a raw signal value.

        The range handed to ``Signal`` in ``create_signal`` extends past the
        largest member, so the signal can carry a value that is not a member
        of this enum. Such a value is rendered as a placeholder string
        rather than letting ``ValueError`` escape from the decoder.

        :param value: raw value to decode.
        :returns: ``str(member)`` for a valid value, otherwise a placeholder.
        """
        try:
            return str(cls(value))
        except ValueError:
            return f"<invalid {cls.__name__}: {value}>"

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: number of extra caller frames to skip; one more
            is added here to account for this method itself.
        :param kwargs: passed through unchanged to ``Signal``.
        :returns: the new ``Signal``.
        """
        values = [int(member) for member in cls]
        # NOTE(review): the upper bound is ``max + 2`` although ``range`` is
        # already exclusive; kept as-is so the signal shape does not change.
        return Signal(range(min(values), max(values) + 2),
                      src_loc_at=(src_loc_at + 1),
                      decoder=cls._decode,
                      **kwargs)


DP = DivPipeCoreOperation
54
55
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    :attribute supported: list of the ``DivPipeCoreOperation`` values the
        pipeline has to implement. Defaults to all three operations.
    """

    def __init__(self, bit_width, fract_width, log2_radix, supported=None):
        """ Create a ``DivPipeCoreConfig`` instance.

        :param bit_width: base bit-width.
        :param fract_width: base fract-width.
        :param log2_radix: bits of ``quotient_root`` computed per stage.
        :param supported: iterable of supported operations, or ``None`` for
            all of them.
        """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        if supported is None:
            supported = [DP.SqrtRem, DP.UDivRem, DP.RSqrtRem]
        # Always store a list: other classes in this file compare
        # ``supported == [DP.UDivRem]``, which is never true for a tuple or
        # set even when the contents match. Copying also decouples the
        # config from later mutation of the caller's container.
        self.supported = list(supported)
        print(f"{self}: n_stages={self.n_stages}")

    def __repr__(self):
        """ Get repr. """
        return f"DivPipeCoreConfig({self.bit_width}, " \
            + f"{self.fract_width}, {self.log2_radix})"

    @property
    def n_stages(self):
        """ Get the number of ``DivPipeCoreCalculateStage`` needed.

        This is ``bit_width / log2_radix`` rounded up.
        """
        return (self.bit_width + self.log2_radix - 1) // self.log2_radix
85
86
class DivPipeCoreInputData:
    """ Input record consumed by ``DivPipeCore``.

    :attribute core_config: the ``DivPipeCoreConfig`` this record was built
        for.
    :attribute dividend: dividend for div/rem. Signal that is
        ``core_config.bit_width + core_config.fract_width`` bits wide, with
        a fract-width of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal that is ``core_config.bit_width`` bits wide, with
        a fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInputData`` instance.

        :param core_config: configuration supplying the signal widths.
        :param reset_less: forwarded to every signal created here.
        """
        self.core_config = core_config
        width = core_config.bit_width
        self.dividend = Signal(width + core_config.fract_width,
                               reset_less=reset_less)
        self.divisor_radicand = Signal(width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals in declaration order. """
        yield from (self.dividend, self.divisor_radicand, self.operation)

    def eq(self, rhs):
        """ Build the assignments copying each member signal from ``rhs``.

        :param rhs: object providing the same member attributes.
        :returns: list of assignment statements, one per member.
        """
        members = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in members]
122
123
class DivPipeCoreInterstageData:
    """ Record passed between the stages of ``DivPipeCore``.

    :attribute core_config: the ``DivPipeCoreConfig`` this record was built
        for.
    :attribute compare_len: width in bits of ``compare_lhs`` and
        ``compare_rhs``: ``core_config.bit_width * 2`` when
        ``core_config.supported`` is exactly ``[DP.UDivRem]``, otherwise
        ``core_config.bit_width * 3``.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal that is ``core_config.bit_width`` bits wide, with
        a fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal that is ``core_config.bit_width`` bits wide, with
        a fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal that is ``core_config.bit_width * 2`` bits wide, with a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: the left-hand-side of the comparison in the
        equation to be solved. Signal that is ``compare_len`` bits wide.
    :attribute compare_rhs: the right-hand-side of the comparison in the
        equation to be solved. Signal that is ``compare_len`` bits wide.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInterstageData`` instance.

        :param core_config: configuration supplying the signal widths.
        :param reset_less: forwarded to every signal created here.
        """
        self.core_config = core_config
        width = core_config.bit_width
        # a divide-only pipeline gets by with narrower comparison operands
        div_only = core_config.supported == [DP.UDivRem]
        self.compare_len = width * 2 if div_only else width * 3
        self.divisor_radicand = Signal(width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.root_times_radicand = Signal(width * 2, reset_less=reset_less)
        self.compare_lhs = Signal(self.compare_len, reset_less=reset_less)
        self.compare_rhs = Signal(self.compare_len, reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals in declaration order. """
        yield from (self.divisor_radicand,
                    self.operation,
                    self.quotient_root,
                    self.root_times_radicand,
                    self.compare_lhs,
                    self.compare_rhs)

    def eq(self, rhs):
        """ Build the assignments copying each member signal from ``rhs``.

        :param rhs: object providing the same member attributes.
        :returns: list of assignment statements, one per member.
        """
        members = ("divisor_radicand",
                   "operation",
                   "quotient_root",
                   "root_times_radicand",
                   "compare_lhs",
                   "compare_rhs")
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in members]
181
182
class DivPipeCoreOutputData:
    """ Output record produced by ``DivPipeCore``.

    :attribute core_config: the ``DivPipeCoreConfig`` this record was built
        for.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal that is ``core_config.bit_width`` bits wide, with
        a fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal that is ``core_config.bit_width * 3`` bits wide, with a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreOutputData`` instance.

        :param core_config: configuration supplying the signal widths.
        :param reset_less: forwarded to every signal created here.
        """
        self.core_config = core_config
        width = core_config.bit_width
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.remainder = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals in declaration order. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Build the assignments copying each member signal from ``rhs``.

        :param rhs: object providing the same member attributes.
        :returns: list of assignment statements, one per member.
        """
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in ("quotient_root", "remainder")]
213
214
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: the result accumulators start at zero and
    ``compare_lhs`` is loaded according to the requested operation.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and drives the stage
        input from ``i``.
        """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process.

        ``i`` is ignored; the already-connected output record is returned.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        inp, out = self.i, self.o
        fw = self.core_config.fract_width

        # The left-hand-side of the equation to solve depends on the
        # operation. No case matches an unlisted operation value, so ``lhs``
        # is not assigned for it.
        lhs = Signal(self.core_config.bit_width * 3, reset_less=True)
        with m.Switch(inp.operation):
            with m.Case(int(DP.UDivRem)):
                comb += lhs.eq(inp.dividend << fw)
            with m.Case(int(DP.SqrtRem)):
                comb += lhs.eq(inp.divisor_radicand << (fw * 2))
            with m.Case(int(DP.RSqrtRem)):
                comb += lhs.eq(1 << (fw * 3))

        # values that are simply forwarded to the first calculate stage
        comb += [
            out.divisor_radicand.eq(inp.divisor_radicand),
            out.operation.eq(inp.operation),
            out.compare_lhs.eq(lhs),
        ]

        # nothing has been computed yet: accumulators start at zero
        comb += [
            out.quotient_root.eq(0),
            out.root_times_radicand.eq(0),
            out.compare_rhs.eq(0),
        ]

        return m
266
267
class Trial(Elaboratable):
    """ Compute the candidate ``compare_rhs`` for one fixed trial value.

    ``DivPipeCoreCalculateStage`` instantiates one ``Trial`` per possible
    value of the bits it resolves. Each instance adds the contribution of
    its constant ``trial_bits`` (placed at ``current_shift``) to the incoming
    ``compare_rhs`` and presents the sum on ``trial_compare_rhs``.

    :attribute divisor_radicand: input, ``bit_width`` bits.
    :attribute quotient_root: input, ``bit_width`` bits.
    :attribute root_times_radicand: input, ``bit_width * 2`` bits.
    :attribute compare_rhs: input, ``bit_width * 3`` bits.
    :attribute operation: input, the ``DivPipeCoreOperation`` selector.
    :attribute trial_compare_rhs: output, ``compare_len`` bits.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: the ``DivPipeCoreConfig`` in use.
        :param trial_bits: the constant candidate value tried here.
        :param current_shift: bit position at which ``trial_bits`` is placed.
        :param log2_radix: number of bits in ``trial_bits``.
        """
        bw = core_config.bit_width
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        div_only = core_config.supported == [DP.UDivRem]
        self.compare_len = bw * 2 if div_only else bw * 3
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(bw * 3, reset_less=True)
        self.trial_compare_rhs = Signal(self.compare_len, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        Logic is only generated for the operations listed in
        ``core_config.supported``.
        """
        m = Module()
        comb = m.d.comb

        cc = self.core_config
        fw = cc.fract_width
        shift = self.current_shift
        dr = self.divisor_radicand

        # the candidate value and its square, as constants
        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
                                    self.log2_radix * 2)

        tblen = cc.bit_width + self.log2_radix

        # NOTE: locals bound directly to ``Signal(...)`` keep their names on
        # purpose; nmigen takes the signal name from the assignment target.

        # UDivRem: rhs + ((dr * trial) << (fw + shift))
        if DP.UDivRem in cc.supported:
            with m.If(self.operation == int(DP.UDivRem)):
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)

                div_term1 = dr_times_trial_bits << (fw + shift)
                div_rhs = self.compare_rhs + div_term1

                comb += self.trial_compare_rhs.eq(div_rhs)

        # SqrtRem: rhs + ((qr * trial) << (fw + shift + 1))
        #              + (trial**2 << (fw + shift * 2))
        if DP.SqrtRem in cc.supported:
            with m.If(self.operation == int(DP.SqrtRem)):
                qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
                comb += qr_times_trial_bits.eq(self.quotient_root
                                               * trial_bits_sig)

                sqrt_term1 = qr_times_trial_bits << (fw + shift + 1)
                sqrt_term2 = trial_bits_sqrd_sig << (fw + shift * 2)
                sqrt_rhs = (self.compare_rhs + sqrt_term1) + sqrt_term2

                comb += self.trial_compare_rhs.eq(sqrt_rhs)

        # RSqrtRem: rhs + ((rr * trial) << (shift + 1))
        #               + ((dr * trial**2) << (shift * 2))
        if DP.RSqrtRem in cc.supported:
            with m.If(self.operation == int(DP.RSqrtRem)):
                tblen2 = cc.bit_width + self.log2_radix * 2
                dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
                comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
                rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
                comb += rr_times_trial_bits.eq(self.root_times_radicand
                                               * trial_bits_sig)

                rsqrt_term1 = rr_times_trial_bits << (shift + 1)
                rsqrt_term2 = dr_times_trial_bits_sqrd << (shift * 2)
                rsqrt_rhs = (self.compare_rhs + rsqrt_term1) + rsqrt_term2

                comb += self.trial_compare_rhs.eq(rsqrt_rhs)

        return m
354
355
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage resolves up to ``core_config.log2_radix`` further bits of
    ``quotient_root``, working from the most significant bits downwards as
    ``stage_index`` increases.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: the ``DivPipeCoreConfig`` in use.
        :param stage_index: position of this stage; must lie in
            ``range(core_config.n_stages)``.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` (named after its
        ``stage_index``) and drives the stage input from ``i``.
        """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process.

        ``i`` is ignored; the already-connected output record is returned.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        cc = self.core_config
        bw = cc.bit_width

        # values that do not change from stage to stage are passed through
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.operation.eq(self.i.operation)
        comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # Work out which bits this stage resolves. The stage may have fewer
        # than ``cc.log2_radix`` bits left when ``bit_width`` is not a
        # multiple of ``log2_radix``.
        bits_left = bw - self.stage_index * cc.log2_radix
        log2_radix = min(cc.log2_radix, bits_left)
        assert log2_radix > 0
        current_shift = bits_left - log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              + f" of {self.core_config.n_stages} handling "
              + f"bits [{current_shift}, {current_shift+log2_radix})"
              + f" of {self.core_config.bit_width}")
        radix = 1 << log2_radix

        # One Trial submodule per candidate value of the new bits. Each
        # yields a candidate compare_rhs; pass_flag records whether that
        # candidate still satisfies compare_lhs >= rhs.
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(cc, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)

            comb += [
                t.divisor_radicand.eq(self.i.divisor_radicand),
                t.quotient_root.eq(self.i.quotient_root),
                t.root_times_radicand.eq(self.i.root_times_radicand),
                t.compare_rhs.eq(self.i.compare_rhs),
                t.operation.eq(self.i.operation),
            ]

            # the candidate rhs is needed for every trial, including trial 0
            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # Each Trial only adds terms to the incoming compare_rhs, so the
            # candidate rhs does not shrink as trial_bits grows; once one
            # trial fails the comparison, the larger ones are expected to
            # fail as well.
            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            if trial_bits == 0:
                # trial 0 is taken to pass without comparing (see the
                # pipeline invariant noted below)
                comb += pass_flag.eq(1)
            else:
                comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
            pfl.append(pass_flag)

        # gather the individual flags into one vector, bit N <-> trial N
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pfl))

        # Convert pass_flags (unary priority) to next_bits (binary index).
        #
        # Assumes that for each set bit in pass_flags, all lower bits are
        # also set.
        #
        # Assumes that pass_flags[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # The priority encoder is fed the inverted flags, so it locates a
        # failing trial; the bits to keep are one below that. ``pe.n`` is
        # used as "no trial failed", in which case the largest candidate
        # value is chosen.
        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        comb += pe.i.eq(~pass_flags)
        with m.If(pe.n):
            comb += next_bits.eq(radix-1)
        with m.Else():
            comb += next_bits.eq(pe.o-1)

        # Select the rhs of the chosen trial. Done as a mask-and-OR tree
        # (treereduce) because Array on such massively long numbers is
        # insanely gate-hungry.
        crhs = []
        for i, tcrh in enumerate(trial_compare_rhs_values):
            nbe = Signal(reset_less=True)
            comb += nbe.eq(next_bits == i)
            crhs.append(Repl(nbe, bw*3) & tcrh)
        comb += self.o.compare_rhs.eq(treereduce(crhs, operator.or_,
                                                 lambda x:x))

        # fold the newly resolved bits into the running results
        new_digit = next_bits << current_shift
        comb += self.o.quotient_root.eq(self.i.quotient_root | new_digit)
        if DP.RSqrtRem in cc.supported:
            dr_times_next = self.i.divisor_radicand * next_bits
            rr = self.i.root_times_radicand + (dr_times_next << current_shift)
            comb += self.o.root_times_radicand.eq(rr)

        return m
478
479
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData``: the quotient/root is forwarded and the
    remainder is formed as ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and drives the stage
        input from ``i``.
        """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process.

        ``i`` is ignored; the already-connected output record is returned.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        inp, out = self.i, self.o

        remainder = inp.compare_lhs - inp.compare_rhs
        m.d.comb += [
            out.quotient_root.eq(inp.quotient_root),
            out.remainder.eq(remainder),
        ]

        return m
513
514 return m