add rest of DivPipeCore
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux)
22 import enum
23
24 # TODO
25 #from ieee754.fpcommon.fpbase import FPNumBaseRecord
26 #from ieee754.fpcommon.getop import FPPipeContext
27
28
class DivPipeCoreConfig:
    """ Configuration for the core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Store the widths and per-stage radix describing the core. """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix

    def __repr__(self):
        """ Return ``DivPipeCoreConfig(bit_width, fract_width, log2_radix)``. """
        return (f"DivPipeCoreConfig({self.bit_width}, "
                f"{self.fract_width}, {self.log2_radix})")

    @property
    def num_calculate_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances required.

        Each calculate stage produces ``log2_radix`` bits of
        ``quotient_root``, so this is ``bit_width / log2_radix`` rounded up.
        """
        bits_per_stage = self.log2_radix
        total_bits = self.bit_width
        return (total_bits + bits_per_stage - 1) // bits_per_stage
54
55
class DivPipeCoreOperation(enum.IntEnum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    UDivRem = 0
    SqrtRem = 1
    RSqrtRem = 2

    @classmethod
    def _signal_bounds(cls):
        """ Get the ``(min, max)`` arguments for a ``Signal`` holding any member.

        nmigen's ``Signal(min=..., max=...)`` treats ``max`` as EXCLUSIVE, so
        it must be one more than the largest member value; otherwise the
        largest member (``RSqrtRem``) would not fit and would be truncated.
        """
        values = [int(member) for member in cls]
        return min(values), max(values) + 1

    @classmethod
    def _decode(cls, value):
        """ Convert a raw signal value to a display string for waveforms.

        Values that are not members (possible because the signal width is
        rounded up to a whole number of bits) are shown as plain integers.
        """
        try:
            return cls(value).name
        except ValueError:
            return str(value)

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``. """
        min_value, max_value = cls._signal_bounds()
        return Signal(min=min_value,
                      max=max_value,
                      src_loc_at=(src_loc_at + 1),
                      decoder=cls._decode,
                      **kwargs)
76
77
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreInputData`` instance. """
        self.core_config = core_config
        self.dividend = Signal(core_config.bit_width + core_config.fract_width,
                               reset_less=True)
        self.divisor_radicand = Signal(core_config.bit_width, reset_less=True)

        # FIXME: this goes into (is replaced by) self.ctx.op
        self.operation = DivPipeCoreOperation.create_signal(reset_less=True)

        # TODO: needs a width argument and a pspec; once available, add:
        #   self.z = FPNumBaseRecord(width, False)
        #   self.out_do_z = Signal(reset_less=True)
        #   self.oz = Signal(width, reset_less=True)
        #   self.ctx = FPPipeContext(width, pspec)  # context: muxid, operator etc.
        #   self.muxid = self.ctx.muxid  # annoying. complicated.

    def __iter__(self):
        """ Get member signals. """
        yield self.dividend
        yield self.divisor_radicand
        yield self.operation  # FIXME: delete. already covered by self.ctx
        # TODO: once the fields above exist, also yield:
        #   self.z, self.out_do_z, self.oz, then everything in self.ctx

    def eq(self, rhs):
        """ Assign member signals from ``rhs``; returns a list of statements. """
        # TODO: once the fields exist, also assign:
        #   self.out_do_z.eq(rhs.out_do_z), self.oz.eq(rhs.oz),
        #   self.ctx.eq(rhs.ctx)
        return [self.dividend.eq(rhs.dividend),
                self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation)]  # FIXME: delete.
129
130
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreInterstageData`` instance. """
        self.core_config = core_config
        self.divisor_radicand = Signal(core_config.bit_width, reset_less=True)
        # XXX FIXME: delete. already covered by self.ctx.op
        self.operation = DivPipeCoreOperation.create_signal(reset_less=True)
        self.quotient_root = Signal(core_config.bit_width, reset_less=True)
        self.root_times_radicand = Signal(core_config.bit_width * 2,
                                          reset_less=True)
        self.compare_lhs = Signal(core_config.bit_width * 3, reset_less=True)
        self.compare_rhs = Signal(core_config.bit_width * 3, reset_less=True)
        # TODO: needs a width argument and a pspec; once available, add:
        #   self.z = FPNumBaseRecord(width, False)
        #   self.out_do_z = Signal(reset_less=True)
        #   self.oz = Signal(width, reset_less=True)
        #   self.ctx = FPPipeContext(width, pspec)  # context: muxid, operator etc.
        #   self.muxid = self.ctx.muxid  # annoying. complicated.

    def __iter__(self):
        """ Get member signals. """
        yield self.divisor_radicand
        yield self.operation  # XXX FIXME: delete. already in self.ctx.op
        yield self.quotient_root
        yield self.root_times_radicand
        yield self.compare_lhs
        yield self.compare_rhs
        # TODO: once the fields above exist, also yield:
        #   self.z, self.out_do_z, self.oz, then everything in self.ctx

    def eq(self, rhs):
        """ Assign member signals from ``rhs``; returns a list of statements. """
        # TODO: once the fields exist, also assign:
        #   self.out_do_z.eq(rhs.out_do_z), self.oz.eq(rhs.oz),
        #   self.ctx.eq(rhs.ctx)
        return [self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),  # FIXME: delete.
                self.quotient_root.eq(rhs.quotient_root),
                self.root_times_radicand.eq(rhs.root_times_radicand),
                self.compare_lhs.eq(rhs.compare_lhs),
                self.compare_rhs.eq(rhs.compare_rhs)]
200
201
class DivPipeCoreOutputData:
    """ Output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config):
        """ Build the output signals, sized from ``core_config``. """
        self.core_config = core_config
        width = core_config.bit_width
        self.quotient_root = Signal(width, reset_less=True)
        self.remainder = Signal(width * 3, reset_less=True)

    def __iter__(self):
        """ Yield ``quotient_root`` and then ``remainder``. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Return the statements assigning each member signal from ``rhs``. """
        names = ("quotient_root", "remainder")
        return [getattr(self, name).eq(getattr(rhs, name)) for name in names]
231
232
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline. """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        Zeroes ``quotient_root``, ``root_times_radicand`` and ``compare_rhs``.
        Sets ``compare_lhs`` to the left-hand side of the selected
        operation's formula, shifted so its fract-width is
        ``fract_width * 3`` bits:

        * div/rem: ``dividend`` (fract-width ``fract_width * 2``) shifted
          left by ``fract_width``.
        * sqrt/rem: ``divisor_radicand`` (fract-width ``fract_width``)
          shifted left by ``fract_width * 2``.
        * rsqrt/rem: the constant 1.0.
        """
        m = Module()
        fract_width = self.core_config.fract_width

        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        m.d.comb += self.o.quotient_root.eq(0)
        m.d.comb += self.o.root_times_radicand.eq(0)

        with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
            m.d.comb += self.o.compare_lhs.eq(self.i.dividend
                                              << fract_width)
        with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
            m.d.comb += self.o.compare_lhs.eq(
                self.i.divisor_radicand << (fract_width * 2))
        with m.Else():  # DivPipeCoreOperation.RSqrtRem
            m.d.comb += self.o.compare_lhs.eq(
                1 << (fract_width * 3))

        m.d.comb += self.o.compare_rhs.eq(0)
        m.d.comb += self.o.operation.eq(self.i.operation)

        # TODO: once the data classes carry these fields, also forward:
        #   m.d.comb += self.o.oz.eq(self.i.oz)
        #   m.d.comb += self.o.out_do_z.eq(self.i.out_do_z)
        #   m.d.comb += self.o.ctx.eq(self.i.ctx)
        return m
286
287
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline. """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` for the pipeline.
        :param stage_index: index of this stage, in
            ``range(core_config.num_calculate_stages)``.
        """
        self.core_config = core_config
        assert stage_index in range(core_config.num_calculate_stages)
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        Decides the next ``log2_radix`` bits of ``quotient_root``. The last
        stage handles fewer bits when ``bit_width`` is not a multiple of
        ``log2_radix``. The stage evaluates the comparison right-hand-side for
        every candidate value of those bits in parallel. It then selects the
        largest candidate that still satisfies ``compare_lhs >= compare_rhs``.
        """
        m = Module()
        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        m.d.comb += self.o.operation.eq(self.i.operation)
        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        fract_width = self.core_config.fract_width
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        # the last stage may have fewer than log2_radix bits left to compute
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        # bit position in quotient_root of the lowest bit this stage computes
        current_shift -= log2_radix
        radix = 1 << log2_radix

        trial_compare_rhs_values = []
        pass_flags = []
        for trial_bits in range(radix):
            shifted_trial_bits = Const(trial_bits, log2_radix) << current_shift
            shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits

            # UDivRem: rhs += (divisor * trial) at fract-width 3*fract_width
            div_factor1 = self.i.divisor_radicand * shifted_trial_bits
            div_rhs = self.i.compare_rhs + (div_factor1 << fract_width)

            # SqrtRem: (root + trial)**2 == root**2 + 2*root*trial + trial**2
            sqrt_factor1 = self.i.quotient_root * (shifted_trial_bits << 1)
            sqrt_factor2 = shifted_trial_bits_sqrd
            sqrt_rhs = (self.i.compare_rhs
                        + (sqrt_factor1 << fract_width)
                        + (sqrt_factor2 << fract_width))

            # RSqrtRem: (root + trial)**2 * radicand
            #   == root**2 * radicand + 2 * trial * (root * radicand)
            #      + trial**2 * radicand
            rsqrt_rhs = (self.i.compare_rhs
                         + self.i.root_times_radicand
                         * (shifted_trial_bits << 1)
                         + self.i.divisor_radicand * shifted_trial_bits_sqrd)

            trial_compare_rhs = self.o.compare_rhs.like(
                name=f"trial_compare_rhs_{trial_bits}")

            with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
                m.d.comb += trial_compare_rhs.eq(div_rhs)
            with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
                m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
            with m.Else():  # DivPipeCoreOperation.RSqrtRem
                m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
            trial_compare_rhs_values.append(trial_compare_rhs)

            pass_flag = Signal(name=f"pass_flag_{trial_bits}")
            m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
            pass_flags.append(pass_flag)

        # convert pass_flags to next_bits.
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set (a thermometer code), and that pass_flag[0] is always set
        # (since compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # If flags 0..k are set, bit b of k equals the parity of
        # (1 + number of multiples of 2**b in [0, k]), i.e. 1 XORed with
        # every pass_flags[j] where j is a multiple of 2**b.
        next_bits = Signal(log2_radix)
        for bit_index in range(log2_radix):
            bit_value = 1
            for j in range(0, radix, 1 << bit_index):
                bit_value ^= pass_flags[j]
            m.d.comb += next_bits[bit_index].eq(bit_value)

        # select the rhs of the highest passing trial: with a thermometer
        # code exactly one index i has pass_flags[i] set and the next one
        # clear (or is the last index).
        next_compare_rhs = 0
        for i in range(radix):
            if i + 1 < radix:
                is_selected = pass_flags[i] & ~pass_flags[i + 1]
            else:
                is_selected = pass_flags[i]
            next_compare_rhs |= Mux(is_selected,
                                    trial_compare_rhs_values[i],
                                    0)

        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                  + ((self.i.divisor_radicand
                                                      * next_bits)
                                                     << current_shift))
        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
                                            | (next_bits << current_shift))
        return m
399
400
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline. """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        # must differ from the setup stage's "div_pipe_core_setup": both
        # stages may be registered as submodules of the same module, and
        # nmigen rejects duplicate submodule names.
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        Passes ``quotient_root`` through. The remainder is
        ``compare_lhs - compare_rhs``.
        """
        m = Module()

        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root)
        m.d.comb += self.o.remainder.eq(self.i.compare_lhs
                                        - self.i.compare_rhs)

        return m