# See Notices.txt for copyright information
""" Core of the div/rem/sqrt/rsqrt pipeline.

Special case handling, input/output conversion, and muxid handling are handled
outside of these classes.

Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.

Formulas solved are:
* div/rem:
    ``dividend == quotient_root * divisor_radicand``
* sqrt/rem:
    ``divisor_radicand == quotient_root * quotient_root``
* rsqrt/rem:
    ``1 == quotient_root * quotient_root * divisor_radicand``

The remainder is the left-hand-side of the comparison minus the
right-hand-side of the comparison in the above formulas.
"""
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Array)
22 from nmigen.lib.coding import PriorityEncoder
23 import enum
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Create a ``DivPipeCoreConfig`` instance.

        :param bit_width: base bit-width.
        :param fract_width: base fract-width.
        :param log2_radix: bits of ``quotient_root`` computed per stage.

        Prints the configuration and the resulting stage count.
        """
        self.bit_width = bit_width
        self.fract_width = fract_width
        # must be stored before the print below: ``__repr__`` and
        # ``n_stages`` both read it.
        self.log2_radix = log2_radix
        print(f"{self}: n_stages={self.n_stages}")

    def __repr__(self):
        """ Get repr, formatted like the constructor call. """
        return (f"DivPipeCoreConfig({self.bit_width}, "
                f"{self.fract_width}, {self.log2_radix})")

    @property
    def n_stages(self):
        """ Get the number of ``DivPipeCoreCalculateStage`` needed.

        Each stage produces ``log2_radix`` bits of the ``bit_width``-bit
        ``quotient_root``, so this is ``bit_width / log2_radix`` rounded up.
        """
        return (self.bit_width + self.log2_radix - 1) // self.log2_radix
class DivPipeCoreOperation(enum.Enum):
    """ Selects which computation ``DivPipeCore`` performs.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Return the integer encoding of this member. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Build a ``Signal`` able to hold any ``DivPipeCoreOperation``.

        :param src_loc_at: forwarded to ``Signal`` incremented by one
            (presumably so that source-location lookup skips this helper's
            own frame -- confirm against the nmigen ``Signal`` docs).
        :param kwargs: passed through to ``Signal`` unchanged.
        :returns: the new ``Signal``; its decoder renders a raw value as the
            ``str()`` of the matching enum member.
        """
        encodings = [int(member) for member in cls]
        # the range end is the largest encoding plus two.
        value_range = range(min(encodings), max(encodings) + 2)

        def decode(raw):
            return str(cls(raw))

        return Signal(value_range,
                      src_loc_at=(src_loc_at + 1),
                      decoder=decode,
                      **kwargs)


# short alias used throughout this module
DP = DivPipeCoreOperation
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInputData`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every created ``Signal``.
        """
        self.core_config = core_config
        bw = core_config.bit_width
        fw = core_config.fract_width
        self.dividend = Signal(bw + fw, reset_less=reset_less)
        self.divisor_radicand = Signal(bw, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield self.dividend
        yield self.divisor_radicand
        yield self.operation

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing ``dividend``, ``divisor_radicand`` and
            ``operation``.
        :returns: list of assignment statements, one per member signal.
        """
        return [self.dividend.eq(rhs.dividend),
                self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),
                ]
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInterstageData`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every created ``Signal``.
        """
        self.core_config = core_config
        bw = core_config.bit_width
        self.divisor_radicand = Signal(bw, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(bw, reset_less=reset_less)
        self.root_times_radicand = Signal(bw * 2, reset_less=reset_less)
        self.compare_lhs = Signal(bw * 3, reset_less=reset_less)
        self.compare_rhs = Signal(bw * 3, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield self.divisor_radicand
        yield self.operation
        yield self.quotient_root
        yield self.root_times_radicand
        yield self.compare_lhs
        yield self.compare_rhs

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing the same six member signals.
        :returns: list of assignment statements, one per member signal.
        """
        return [self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),
                self.quotient_root.eq(rhs.quotient_root),
                self.root_times_radicand.eq(rhs.root_times_radicand),
                self.compare_lhs.eq(rhs.compare_lhs),
                self.compare_rhs.eq(rhs.compare_rhs)]
class DivPipeCoreOutputData:
    """ Result record produced by ``DivPipeCore``.

    :attribute core_config: the ``DivPipeCoreConfig`` the record was sized
        from.
    :attribute quotient_root: quotient or root part of the result. Signal
        ``core_config.bit_width`` bits wide with a fract-width of
        ``core_config.fract_width`` bits.
    :attribute remainder: remainder part of the result. Signal
        ``core_config.bit_width * 3`` bits wide with a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the output signals for ``core_config``.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every created ``Signal``.
        """
        self.core_config = core_config
        width = core_config.bit_width
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.remainder = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals: quotient_root, then remainder. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Build the statements copying ``rhs`` into this record.

        :param rhs: object providing ``quotient_root`` and ``remainder``.
        :returns: list of two assignment statements.
        """
        assignments = []
        assignments.append(self.quotient_root.eq(rhs.quotient_root))
        assignments.append(self.remainder.eq(rhs.remainder))
        return assignments
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: a zero ``quotient_root``, a zero
    ``compare_rhs``, and the operation-specific ``compare_lhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` used to size the records.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and connects ``i`` to
        the stage input combinatorially.
        """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        # the divisor/radicand is needed unchanged by every later stage;
        # nothing of the result has been computed yet, so the partial
        # results start at zero.
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.quotient_root.eq(0)
        comb += self.o.root_times_radicand.eq(0)

        lhs = Signal(self.core_config.bit_width * 3, reset_less=True)
        fw = self.core_config.fract_width

        # compare_lhs has a fract-width of fw * 3, so each operand is
        # shifted left by however many fraction bits it is short of that.
        with m.Switch(self.i.operation):
            with m.Case(int(DP.UDivRem)):
                # dividend: fract-width fw * 2
                comb += lhs.eq(self.i.dividend << fw)
            with m.Case(int(DP.SqrtRem)):
                # radicand: fract-width fw
                comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
            with m.Case(int(DP.RSqrtRem)):
                # the constant 1: no fraction bits
                comb += lhs.eq(1 << (fw * 3))

        comb += self.o.compare_lhs.eq(lhs)
        comb += self.o.compare_rhs.eq(0)
        comb += self.o.operation.eq(self.i.operation)

        return m
class Trial(Elaboratable):
    """ Candidate ``compare_rhs`` for one value of the next result digit.

    With ``t = trial_bits`` and ``s = current_shift``, the candidate result
    is ``quotient_root + (t << s)``. ``trial_compare_rhs`` is the
    right-hand-side of the operation's formula evaluated for that candidate,
    computed incrementally from the incoming ``compare_rhs``:

    * div/rem:   ``rhs + ((dr * t) << s)``
    * sqrt/rem:  ``rhs + ((qr * t) << (s + 1)) + ((t * t) << (s * 2))``
    * rsqrt/rem: ``rhs + ((rr * t) << (s + 1)) + ((dr * t * t) << (s * 2))``

    where ``dr``/``qr``/``rr`` are ``divisor_radicand``/``quotient_root``/
    ``root_times_radicand``. The div and sqrt terms are shifted a further
    ``fract_width`` bits to reach the ``fract_width * 3`` scaling of
    ``compare_rhs``.

    :attribute divisor_radicand: input, ``bit_width`` bits.
    :attribute quotient_root: input, ``bit_width`` bits.
    :attribute root_times_radicand: input, ``bit_width * 2`` bits.
    :attribute compare_rhs: input, ``bit_width * 3`` bits.
    :attribute operation: input, the ``DivPipeCoreOperation`` selector.
    :attribute trial_compare_rhs: output, ``bit_width * 3`` bits.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param trial_bits: constant digit value evaluated by this trial.
        :param current_shift: bit position of the digit in ``quotient_root``.
        :param log2_radix: number of bits in the digit.
        """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        bw = core_config.bit_width
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(bw * 3, reset_less=True)
        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        dr = self.divisor_radicand
        qr = self.quotient_root
        rr = self.root_times_radicand

        # the digit and its square are elaboration-time constants
        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
                                    self.log2_radix * 2)

        # widths of (bit_width-bit value) * digit and * digit squared
        tblen = self.core_config.bit_width + self.log2_radix
        tblen2 = self.core_config.bit_width + self.log2_radix * 2
        dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
        comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)

        with m.Switch(self.operation):
            # UDivRem
            with m.Case(int(DP.UDivRem)):
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)

                div_term1 = dr_times_trial_bits
                div_term1_shift = self.core_config.fract_width
                div_term1_shift += self.current_shift
                div_rhs = self.compare_rhs + (div_term1 << div_term1_shift)

                comb += self.trial_compare_rhs.eq(div_rhs)

            # SqrtRem
            with m.Case(int(DP.SqrtRem)):
                qr_times_trial_bits = Signal((tblen + 1) * 2, reset_less=True)
                comb += qr_times_trial_bits.eq(qr * trial_bits_sig)

                sqrt_term1 = qr_times_trial_bits
                sqrt_term1_shift = self.core_config.fract_width
                sqrt_term1_shift += self.current_shift + 1
                sqrt_term2 = trial_bits_sqrd_sig
                sqrt_term2_shift = self.core_config.fract_width
                sqrt_term2_shift += self.current_shift * 2
                sqrt_rhs = (self.compare_rhs
                            + (sqrt_term1 << sqrt_term1_shift)
                            + (sqrt_term2 << sqrt_term2_shift))

                comb += self.trial_compare_rhs.eq(sqrt_rhs)

            # RSqrtRem
            with m.Case(int(DP.RSqrtRem)):
                rr_times_trial_bits = Signal((tblen + 1) * 3, reset_less=True)
                comb += rr_times_trial_bits.eq(rr * trial_bits_sig)

                rsqrt_term1 = rr_times_trial_bits
                rsqrt_term1_shift = self.current_shift + 1
                rsqrt_term2 = dr_times_trial_bits_sqrd
                rsqrt_term2_shift = self.current_shift * 2
                rsqrt_rhs = (self.compare_rhs
                             + (rsqrt_term1 << rsqrt_term1_shift)
                             + (rsqrt_term2 << rsqrt_term2_shift))

                comb += self.trial_compare_rhs.eq(rsqrt_rhs)

        return m
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage determines one radix-``2 ** log2_radix`` digit of
    ``quotient_root``: it evaluates one ``Trial`` per possible digit value
    and keeps the largest digit whose trial ``compare_rhs`` does not exceed
    ``compare_lhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` used to size the records.
        :param stage_index: position of this stage, in
            ``range(core_config.n_stages)``; stage 0 produces the most
            significant digit.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` (named after its stage
        index) and connects ``i`` to the stage input combinatorially.
        """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        # copy invariant inputs to outputs (for next stage)
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.operation.eq(self.i.operation)
        comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # constants
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        # the last stage handles fewer bits when bit_width is not a
        # multiple of log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              + f" of {self.core_config.n_stages} handling "
              + f"bits [{current_shift}, {current_shift+log2_radix})"
              + f" of {self.core_config.bit_width}")
        radix = 1 << log2_radix

        # trials within this radix range.  carried out by Trial module,
        # results stored in pass_flags.  pass_flags are unary priority.
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)

            comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            comb += t.quotient_root.eq(self.i.quotient_root)
            comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            comb += t.compare_rhs.eq(self.i.compare_rhs)
            comb += t.operation.eq(self.i.operation)

            # get the trial output
            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # make the trial comparison against the [invariant] lhs.
            # trial_compare_rhs never decreases as trial_bits increases, so
            # once one trial fails every larger one fails too.
            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
            pfl.append(pass_flag)

        # Cat all the pass flags list together (easier to handle, below)
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pfl))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # The encoder locates the lowest *failing* trial; the digit is the
        # one just below it.  If no trial fails, the digit is the maximum.
        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        comb += pe.i.eq(~pass_flags)
        with m.If(~pe.n):
            comb += next_bits.eq(pe.o-1)
        with m.Else():
            comb += next_bits.eq(radix-1)

        # get the highest passing rhs trial (indexed by next_bits)
        ta = Array(trial_compare_rhs_values)
        comb += self.o.compare_rhs.eq(ta[next_bits])

        # create outputs for next phase
        qr = self.i.quotient_root | (next_bits << current_shift)
        rr = self.i.root_times_radicand + (
            (self.i.divisor_radicand * next_bits) << current_shift)
        comb += self.o.quotient_root.eq(qr)
        comb += self.o.root_times_radicand.eq(rr)

        return m
class DivPipeCoreFinalStage(Elaboratable):
    """ Last stage of the div/rem/sqrt/rsqrt pipeline core.

    Turns the final ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData``.
    """

    def __init__(self, core_config):
        """ Build the stage and its input/output records.

        :param core_config: ``DivPipeCoreConfig`` used to size the records.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh input record for this stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Return a fresh output record for this stage. """
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Attach this stage to ``m`` and feed it from ``i``. """
        setattr(m.submodules, "div_pipe_core_final", self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return the stage output record; the argument is not used. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        The quotient/root is passed through; the remainder is
        ``compare_lhs - compare_rhs``.
        """
        m = Module()
        stage_in = self.i
        stage_out = self.o
        m.d.comb += [
            stage_out.quotient_root.eq(stage_in.quotient_root),
            stage_out.remainder.eq(stage_in.compare_lhs
                                   - stage_in.compare_rhs),
        ]
        return m