only pass in lhs bit_width * 2 for UDivRem
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Repl)
22 from nmigen.lib.coding import PriorityEncoder
23 from nmutil.util import treereduce
24 import enum
25 import operator
26
27
class DivPipeCoreOperation(enum.Enum):
    """ Operations that ``DivPipeCore`` can compute.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Return the integer encoding of this operation. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: extra stack depth, forwarded (plus one for this
            frame) to ``Signal``.
        :param kwargs: forwarded unchanged to ``Signal``.
        :returns: the newly created ``Signal``.
        """
        encodings = [int(member) for member in cls]
        lowest = min(encodings)
        highest = max(encodings)
        # NOTE: the Signal must be constructed directly in this frame so that
        # ``src_loc_at + 1`` still refers to the caller of create_signal.
        return Signal(range(lowest, highest + 2),
                      src_loc_at=(src_loc_at + 1),
                      decoder=lambda v: str(cls(v)),
                      **kwargs)
51
52
# Short alias for the operation enum, used by the classes in this module.
DP = DivPipeCoreOperation
54
55
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    :attribute supported: the operations the core is configured for;
        all three operations when not given.
    """

    def __init__(self, bit_width, fract_width, log2_radix, supported=None):
        """ Create a ``DivPipeCoreConfig`` instance.

        Prints the configuration and its stage count as a side effect.
        """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        self.supported = ([DP.SqrtRem, DP.UDivRem, DP.RSqrtRem]
                          if supported is None else supported)
        print(f"{self}: n_stages={self.n_stages}")

    def __repr__(self):
        """ Get repr (``supported`` is not included). """
        return (f"DivPipeCoreConfig({self.bit_width}, "
                f"{self.fract_width}, {self.log2_radix})")

    @property
    def n_stages(self):
        """ Get the number of ``DivPipeCoreCalculateStage`` needed. """
        # bit_width / log2_radix, rounded up
        rounded_up = self.bit_width + self.log2_radix - 1
        return rounded_up // self.log2_radix
85
86
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInputData`` instance. """
        self.core_config = core_config
        base_width = core_config.bit_width
        dividend_width = base_width + core_config.fract_width
        # Signals are assigned straight to their attributes: the attribute
        # name is what the Signal picks up as its own name.
        self.dividend = Signal(dividend_width, reset_less=reset_less)
        self.divisor_radicand = Signal(base_width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Get member signals, in declaration order. """
        yield from (self.dividend, self.divisor_radicand, self.operation)

    def eq(self, rhs):
        """ Assign member signals from the same-named members of ``rhs``. """
        member_names = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in member_names]
122
123
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute compare_len: bit-width used for ``compare_lhs`` and
        ``compare_rhs``: ``core_config.bit_width * 2`` when the only
        supported operation is ``UDivRem``, otherwise
        ``core_config.bit_width * 3``.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of ``compare_len``
        and a fract-width of ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of ``compare_len``
        and a fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInterstageData`` instance. """
        self.core_config = core_config
        bw = core_config.bit_width
        # a divide-only core gets by with a narrower comparison
        divide_only = core_config.supported == [DP.UDivRem]
        self.compare_len = bw * (2 if divide_only else 3)
        self.divisor_radicand = Signal(bw, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(bw, reset_less=reset_less)
        self.root_times_radicand = Signal(bw * 2, reset_less=reset_less)
        self.compare_lhs = Signal(self.compare_len, reset_less=reset_less)
        self.compare_rhs = Signal(self.compare_len, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals, in declaration order. """
        yield from (self.divisor_radicand,
                    self.operation,
                    self.quotient_root,
                    self.root_times_radicand,
                    self.compare_lhs,
                    self.compare_rhs)

    def eq(self, rhs):
        """ Assign member signals from the same-named members of ``rhs``. """
        member_names = ("divisor_radicand",
                        "operation",
                        "quotient_root",
                        "root_times_radicand",
                        "compare_lhs",
                        "compare_rhs")
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in member_names]
181
182
class DivPipeCoreOutputData:
    """ output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute compare_len: bit-width of ``remainder``:
        ``core_config.bit_width * 2`` when the only supported operation is
        ``UDivRem``, otherwise ``core_config.bit_width * 3``.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``compare_len`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreOutputData`` instance. """
        self.core_config = core_config
        bw = core_config.bit_width
        # a divide-only core gets by with a narrower remainder
        divide_only = core_config.supported == [DP.UDivRem]
        self.compare_len = bw * (2 if divide_only else 3)
        self.quotient_root = Signal(bw, reset_less=reset_less)
        self.remainder = Signal(self.compare_len, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals, in declaration order. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Assign member signals from the same-named members of ``rhs``. """
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in ("quotient_root", "remainder")]
217
218
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Turns a ``DivPipeCoreInputData`` into the first
    ``DivPipeCoreInterstageData``: the divisor/radicand and operation are
    passed through, ``quotient_root``, ``root_times_radicand`` and
    ``compare_rhs`` start at zero, and ``compare_lhs`` is loaded with the
    operation-specific left-hand side.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` describing the widths and
            the supported operations.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()
        # ``bw`` has to be bound here: it was previously used without being
        # defined, which made construction fail with NameError.
        bw = core_config.bit_width
        # width of the comparison left-hand side: narrower for a
        # divide-only core.
        if core_config.supported == [DP.UDivRem]:
            self.compare_len = bw * 2
        else:
            self.compare_len = bw * 3

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and connects ``i`` to
        the stage's input.
        """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        # pass-through and zero-initialised accumulators
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.quotient_root.eq(0)
        comb += self.o.root_times_radicand.eq(0)

        lhs = Signal(self.compare_len, reset_less=True)
        fw = self.core_config.fract_width

        # left-hand side of the comparison, per operation (see the formulas
        # in the module docstring); each is scaled to a fract-width of fw * 3
        with m.Switch(self.i.operation):
            with m.Case(int(DP.UDivRem)):
                comb += lhs.eq(self.i.dividend << fw)
            with m.Case(int(DP.SqrtRem)):
                comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
            with m.Case(int(DP.RSqrtRem)):
                comb += lhs.eq(1 << (fw * 3))

        comb += self.o.compare_lhs.eq(lhs)
        comb += self.o.compare_rhs.eq(0)
        comb += self.o.operation.eq(self.i.operation)

        return m
274
275
class Trial(Elaboratable):
    """ Compute the comparison right-hand side for one candidate digit.

    Given the running ``compare_rhs`` and a fixed candidate ``trial_bits``
    for the next ``log2_radix`` result bits (positioned at
    ``current_shift``), drives ``trial_compare_rhs`` with ``compare_rhs``
    plus the operation-specific increment.

    :attribute divisor_radicand: input, ``bit_width`` bits.
    :attribute quotient_root: input, ``bit_width`` bits.
    :attribute root_times_radicand: input, ``bit_width * 2`` bits.
    :attribute compare_rhs: input, ``compare_len`` bits.
    :attribute operation: input, the ``DivPipeCoreOperation`` selector.
    :attribute trial_compare_rhs: output, ``compare_len`` bits.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` for the constant candidate ``trial_bits``. """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        bw = core_config.bit_width
        # a divide-only core gets by with a narrower comparison
        divide_only = core_config.supported == [DP.UDivRem]
        self.compare_len = bw * (2 if divide_only else 3)
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(self.compare_len, reset_less=True)
        self.trial_compare_rhs = Signal(self.compare_len, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        Logic is only generated for operations listed in
        ``core_config.supported``.
        """
        m = Module()
        comb = m.d.comb

        cc = self.core_config
        dr = self.divisor_radicand
        fw = cc.fract_width
        shift = self.current_shift

        # the candidate digit and its square, as constants
        trial_const = Const(self.trial_bits, self.log2_radix)
        trial_sqrd_const = Const(self.trial_bits * self.trial_bits,
                                 self.log2_radix * 2)

        tblen = cc.bit_width + self.log2_radix

        # NOTE: locals that receive a Signal keep their original names,
        # because the Signal takes its name from the variable.

        # UDivRem
        if DP.UDivRem in cc.supported:
            with m.If(self.operation == int(DP.UDivRem)):
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(dr * trial_const)

                div_rhs = (self.compare_rhs
                           + (dr_times_trial_bits << (fw + shift)))

                comb += self.trial_compare_rhs.eq(div_rhs)

        # SqrtRem
        if DP.SqrtRem in cc.supported:
            with m.If(self.operation == int(DP.SqrtRem)):
                qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
                comb += qr_times_trial_bits.eq(
                    self.quotient_root * trial_const)

                # cross term (doubled via the extra +1 shift) ...
                sqrt_rhs = (self.compare_rhs
                            + (qr_times_trial_bits << (fw + shift + 1)))
                # ... then the squared-digit term
                sqrt_rhs = sqrt_rhs + (trial_sqrd_const << (fw + shift * 2))

                comb += self.trial_compare_rhs.eq(sqrt_rhs)

        # RSqrtRem
        if DP.RSqrtRem in cc.supported:
            with m.If(self.operation == int(DP.RSqrtRem)):
                tblen2 = cc.bit_width + self.log2_radix*2
                dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
                comb += dr_times_trial_bits_sqrd.eq(dr * trial_sqrd_const)
                rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
                comb += rr_times_trial_bits.eq(
                    self.root_times_radicand * trial_const)

                # cross term (doubled via the extra +1 shift) ...
                rsqrt_rhs = (self.compare_rhs
                             + (rr_times_trial_bits << (shift + 1)))
                # ... then the radicand-times-squared-digit term
                rsqrt_rhs = rsqrt_rhs + (dr_times_trial_bits_sqrd
                                         << (shift * 2))

                comb += self.trial_compare_rhs.eq(rsqrt_rhs)

        return m
362
363
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage determines up to ``core_config.log2_radix`` further bits of
    ``quotient_root`` by evaluating one ``Trial`` per candidate value and
    keeping the highest candidate whose trial right-hand side does not
    exceed ``compare_lhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` for the pipeline.
        :param stage_index: position of this stage; must be in
            ``range(core_config.n_stages)``.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        bw = core_config.bit_width
        # a divide-only core gets by with a narrower comparison
        divide_only = core_config.supported == [DP.UDivRem]
        self.compare_len = bw * (2 if divide_only else 3)
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` (named after its stage
        index) and connects ``i`` to the stage's input.
        """
        submodule_name = f"div_pipe_core_calculate_{self.stage_index}"
        setattr(m.submodules, submodule_name, self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        cc = self.core_config

        # copy invariant inputs to outputs (for next stage)
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.operation.eq(self.i.operation)
        comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # constants: which result bits this stage is responsible for.
        # the last stage may handle fewer than cc.log2_radix bits.
        bits_remaining = cc.bit_width - self.stage_index * cc.log2_radix
        log2_radix = min(cc.log2_radix, bits_remaining)
        assert log2_radix > 0
        current_shift = bits_remaining - log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              f" of {cc.n_stages} handling "
              f"bits [{current_shift}, {current_shift+log2_radix})"
              f" of {cc.bit_width}")
        radix = 1 << log2_radix

        # trials within this radix range. carried out by Trial module,
        # results stored in pass_flags. pass_flags are unary priority.
        trial_compare_rhs_values = []
        pass_flag_list = []
        for trial_bits in range(radix):
            t = Trial(cc, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, f"trial{trial_bits}", t)

            comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            comb += t.quotient_root.eq(self.i.quotient_root)
            comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            comb += t.compare_rhs.eq(self.i.compare_rhs)
            comb += t.operation.eq(self.i.operation)

            # get the trial output (needed even in pass_flags[0] case)
            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # make the trial comparison against the [invariant] lhs.
            # each Trial adds terms scaled by trial_bits to compare_rhs,
            # so a trial passes while its rhs has not yet exceeded the lhs.
            pass_flag = Signal(name=f"pass_flag_{trial_bits}",
                               reset_less=True)
            if trial_bits != 0:
                comb += pass_flag.eq(
                    self.i.compare_lhs >= t.trial_compare_rhs)
            else:
                # do not do first comparison: no point.
                comb += pass_flag.eq(1)
            pass_flag_list.append(pass_flag)

        # Cat all the pass flags list together (easier to handle, below)
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pass_flag_list))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).

        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        # the encoder locates the first failing trial; the one just below
        # it is the highest passing trial.
        comb += pe.i.eq(~pass_flags)
        with m.If(~pe.n):
            comb += next_bits.eq(pe.o-1)
        with m.Else():
            # no trial failed: take the largest candidate
            comb += next_bits.eq(radix-1)

        # get the highest passing rhs trial. use treereduce because
        # Array on such massively long numbers is insanely gate-hungry
        crhs = []
        for index, trial_rhs in enumerate(trial_compare_rhs_values):
            nbe = Signal(reset_less=True)
            comb += nbe.eq(next_bits == index)
            # mask out every trial except the selected one
            crhs.append(Repl(nbe, self.compare_len) & trial_rhs)
        selected_rhs = treereduce(crhs, operator.or_, lambda x: x)
        comb += self.o.compare_rhs.eq(selected_rhs)

        # create outputs for next phase
        qr = self.i.quotient_root | (next_bits << current_shift)
        comb += self.o.quotient_root.eq(qr)
        if DP.RSqrtRem in cc.supported:
            radicand_times_bits = self.i.divisor_radicand * next_bits
            rr = self.i.root_times_radicand + (radicand_times_bits
                                               << current_shift)
            comb += self.o.root_times_radicand.eq(rr)

        return m
490
491
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData``: the quotient/root is passed through and the
    remainder is ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and connects ``i`` to
        the stage's input.
        """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        stage_in = self.i
        stage_out = self.o

        m.d.comb += stage_out.quotient_root.eq(stage_in.quotient_root)
        # remainder: left-hand side minus right-hand side of the comparison
        difference = stage_in.compare_lhs - stage_in.compare_rhs
        m.d.comb += stage_out.remainder.eq(difference)

        return m
525
526 return m