The remainder is the left-hand-side of the comparison minus the
right-hand-side of the comparison in the above formulas.
"""
-from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Array)
+from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Repl)
from nmigen.lib.coding import PriorityEncoder
+from nmutil.util import treereduce
import enum
+import operator
+
+
+class DivPipeCoreOperation(enum.Enum):
+ """ Operation for ``DivPipeCore``.
+
+ :attribute UDivRem: unsigned divide/remainder.
+ :attribute SqrtRem: square-root/remainder.
+ :attribute RSqrtRem: reciprocal-square-root/remainder.
+ """
+
+ SqrtRem = 0
+ UDivRem = 1
+ RSqrtRem = 2
+
+ def __int__(self):
+ """ Convert to int. """
+ return self.value
+
+ @classmethod
+ def create_signal(cls, *, src_loc_at=0, **kwargs):
+ """ Create a signal that can contain a ``DivPipeCoreOperation``. """
+ return Signal(range(min(map(int, cls)), max(map(int, cls)) + 2),
+ src_loc_at=(src_loc_at + 1),
+ decoder=lambda v: str(cls(v)),
+ **kwargs)
+
+
+DP = DivPipeCoreOperation
class DivPipeCoreConfig:
computed per pipeline stage.
"""
- def __init__(self, bit_width, fract_width, log2_radix):
+ def __init__(self, bit_width, fract_width, log2_radix, supported=None):
""" Create a ``DivPipeCoreConfig`` instance. """
self.bit_width = bit_width
self.fract_width = fract_width
self.log2_radix = log2_radix
+ if supported is None:
+ supported = frozenset(DP)
+ else:
+ supported = frozenset(supported)
+ self.supported = supported
print(f"{self}: n_stages={self.n_stages}")
def __repr__(self):
""" Get repr. """
return f"DivPipeCoreConfig({self.bit_width}, " \
- + f"{self.fract_width}, {self.log2_radix})"
+ + f"{self.fract_width}, {self.log2_radix}, "\
+ + f"supported={self.supported})"
@property
def n_stages(self):
return (self.bit_width + self.log2_radix - 1) // self.log2_radix
-class DivPipeCoreOperation(enum.Enum):
- """ Operation for ``DivPipeCore``.
-
- :attribute UDivRem: unsigned divide/remainder.
- :attribute SqrtRem: square-root/remainder.
- :attribute RSqrtRem: reciprocal-square-root/remainder.
- """
-
- UDivRem = 0
- SqrtRem = 1
- RSqrtRem = 2
-
- def __int__(self):
- """ Convert to int. """
- return self.value
-
- @classmethod
- def create_signal(cls, *, src_loc_at=0, **kwargs):
- """ Create a signal that can contain a ``DivPipeCoreOperation``. """
- return Signal(min=min(map(int, cls)),
- max=max(map(int, cls)) + 2,
- src_loc_at=(src_loc_at + 1),
- decoder=lambda v: str(cls(v)),
- **kwargs)
-
-
-DP = DivPipeCoreOperation
-
-
class DivPipeCoreInputData:
""" input data type for ``DivPipeCore``.
def __init__(self, core_config, reset_less=True):
""" Create a ``DivPipeCoreInputData`` instance. """
self.core_config = core_config
- self.dividend = Signal(core_config.bit_width + core_config.fract_width,
- reset_less=reset_less)
- self.divisor_radicand = Signal(core_config.bit_width,
- reset_less=reset_less)
+ bw = core_config.bit_width
+ fw = core_config.fract_width
+ self.dividend = Signal(bw + fw, reset_less=reset_less)
+ self.divisor_radicand = Signal(bw, reset_less=reset_less)
self.operation = DP.create_signal(reset_less=reset_less)
def __iter__(self):
def __init__(self, core_config, reset_less=True):
""" Create a ``DivPipeCoreInterstageData`` instance. """
self.core_config = core_config
- self.divisor_radicand = Signal(core_config.bit_width,
- reset_less=reset_less)
+ bw = core_config.bit_width
+ # TODO(programmerjake): re-enable once bit_width reduction is fixed
+ if False and core_config.supported == {DP.UDivRem}:
+ self.compare_len = bw * 2
+ else:
+ self.compare_len = bw * 3
+ self.divisor_radicand = Signal(bw, reset_less=reset_less)
self.operation = DP.create_signal(reset_less=reset_less)
- self.quotient_root = Signal(core_config.bit_width,
- reset_less=reset_less)
- self.root_times_radicand = Signal(core_config.bit_width * 2,
- reset_less=reset_less)
- self.compare_lhs = Signal(core_config.bit_width * 3,
- reset_less=reset_less)
- self.compare_rhs = Signal(core_config.bit_width * 3,
- reset_less=reset_less)
+ self.quotient_root = Signal(bw, reset_less=reset_less)
+ self.root_times_radicand = Signal(bw * 2, reset_less=reset_less)
+ self.compare_lhs = Signal(self.compare_len, reset_less=reset_less)
+ self.compare_rhs = Signal(self.compare_len, reset_less=reset_less)
def __iter__(self):
""" Get member signals. """
def __init__(self, core_config, reset_less=True):
""" Create a ``DivPipeCoreOutputData`` instance. """
self.core_config = core_config
- self.quotient_root = Signal(core_config.bit_width,
- reset_less=reset_less)
- self.remainder = Signal(core_config.bit_width * 3,
- reset_less=reset_less)
+ bw = core_config.bit_width
+ # TODO(programmerjake): re-enable once bit_width reduction is fixed
+ if False and core_config.supported == {DP.UDivRem}:
+ self.compare_len = bw * 2
+ else:
+ self.compare_len = bw * 3
+ self.quotient_root = Signal(bw, reset_less=reset_less)
+ self.remainder = Signal(self.compare_len, reset_less=reset_less)
def __iter__(self):
""" Get member signals. """
self.core_config = core_config
self.i = self.ispec()
self.o = self.ospec()
+ bw = core_config.bit_width
+ # TODO(programmerjake): re-enable once bit_width reduction is fixed
+ if False and core_config.supported == {DP.UDivRem}:
+ self.compare_len = bw * 2
+ else:
+ self.compare_len = bw * 3
def ispec(self):
""" Get the input spec for this pipeline stage."""
comb += self.o.quotient_root.eq(0)
comb += self.o.root_times_radicand.eq(0)
- lhs = Signal(self.core_config.bit_width * 3, reset_less=True)
+ lhs = Signal(self.compare_len, reset_less=True)
fw = self.core_config.fract_width
- with m.If(self.i.operation == int(DP.UDivRem)):
- comb += lhs.eq(self.i.dividend << fw)
- with m.Elif(self.i.operation == int(DP.SqrtRem)):
- comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
- with m.Else(): # DivPipeCoreOperation.RSqrtRem
- comb += lhs.eq(1 << (fw * 3))
+ with m.Switch(self.i.operation):
+ with m.Case(int(DP.UDivRem)):
+ comb += lhs.eq(self.i.dividend << fw)
+ with m.Case(int(DP.SqrtRem)):
+ comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
+ with m.Case(int(DP.RSqrtRem)):
+ comb += lhs.eq(1 << (fw * 3))
comb += self.o.compare_lhs.eq(lhs)
comb += self.o.compare_rhs.eq(0)
self.current_shift = current_shift
self.log2_radix = log2_radix
bw = core_config.bit_width
+ # TODO(programmerjake): re-enable once bit_width reduction is fixed
+ if False and core_config.supported == {DP.UDivRem}:
+ self.compare_len = bw * 2
+ else:
+ self.compare_len = bw * 3
self.divisor_radicand = Signal(bw, reset_less=True)
self.quotient_root = Signal(bw, reset_less=True)
self.root_times_radicand = Signal(bw * 2, reset_less=True)
- self.compare_rhs = Signal(bw * 3, reset_less=True)
- self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
+ self.compare_rhs = Signal(self.compare_len, reset_less=True)
+ self.trial_compare_rhs = Signal(self.compare_len, reset_less=True)
self.operation = DP.create_signal(reset_less=True)
def elaborate(self, platform):
m = Module()
comb = m.d.comb
+ cc = self.core_config
dr = self.divisor_radicand
- qr = self.quotient_root
- rr = self.root_times_radicand
trial_bits_sig = Const(self.trial_bits, self.log2_radix)
trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
self.log2_radix * 2)
tblen = self.core_config.bit_width+self.log2_radix
- tblen2 = self.core_config.bit_width+self.log2_radix*2
- dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
- comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
# UDivRem
- with m.If(self.operation == int(DP.UDivRem)):
- dr_times_trial_bits = Signal(tblen, reset_less=True)
- comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
- div_rhs = self.compare_rhs
+ if DP.UDivRem in cc.supported:
+ with m.If(self.operation == int(DP.UDivRem)):
+ dr_times_trial_bits = Signal(tblen, reset_less=True)
+ comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
+ div_rhs = self.compare_rhs
- div_term1 = dr_times_trial_bits
- div_term1_shift = self.core_config.fract_width
- div_term1_shift += self.current_shift
- div_rhs += div_term1 << div_term1_shift
+ div_term1 = dr_times_trial_bits
+ div_term1_shift = self.core_config.fract_width
+ div_term1_shift += self.current_shift
+ div_rhs += div_term1 << div_term1_shift
- comb += self.trial_compare_rhs.eq(div_rhs)
+ comb += self.trial_compare_rhs.eq(div_rhs)
# SqrtRem
- with m.Elif(self.operation == int(DP.SqrtRem)):
- qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
- comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
- sqrt_rhs = self.compare_rhs
-
- sqrt_term1 = qr_times_trial_bits
- sqrt_term1_shift = self.core_config.fract_width
- sqrt_term1_shift += self.current_shift + 1
- sqrt_rhs += sqrt_term1 << sqrt_term1_shift
- sqrt_term2 = trial_bits_sqrd_sig
- sqrt_term2_shift = self.core_config.fract_width
- sqrt_term2_shift += self.current_shift * 2
- sqrt_rhs += sqrt_term2 << sqrt_term2_shift
-
- comb += self.trial_compare_rhs.eq(sqrt_rhs)
+ if DP.SqrtRem in cc.supported:
+ with m.If(self.operation == int(DP.SqrtRem)):
+ qr = self.quotient_root
+ qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
+ comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
+ sqrt_rhs = self.compare_rhs
+
+ sqrt_term1 = qr_times_trial_bits
+ sqrt_term1_shift = self.core_config.fract_width
+ sqrt_term1_shift += self.current_shift + 1
+ sqrt_rhs += sqrt_term1 << sqrt_term1_shift
+ sqrt_term2 = trial_bits_sqrd_sig
+ sqrt_term2_shift = self.core_config.fract_width
+ sqrt_term2_shift += self.current_shift * 2
+ sqrt_rhs += sqrt_term2 << sqrt_term2_shift
+
+ comb += self.trial_compare_rhs.eq(sqrt_rhs)
# RSqrtRem
- with m.Else():
- rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
- comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
- rsqrt_rhs = self.compare_rhs
-
- rsqrt_term1 = rr_times_trial_bits
- rsqrt_term1_shift = self.current_shift + 1
- rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
- rsqrt_term2 = dr_times_trial_bits_sqrd
- rsqrt_term2_shift = self.current_shift * 2
- rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
-
- comb += self.trial_compare_rhs.eq(rsqrt_rhs)
+ if DP.RSqrtRem in cc.supported:
+ with m.If(self.operation == int(DP.RSqrtRem)):
+ rr = self.root_times_radicand
+ tblen2 = self.core_config.bit_width+self.log2_radix*2
+ dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
+ comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
+ rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
+ comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
+ rsqrt_rhs = self.compare_rhs
+
+ rsqrt_term1 = rr_times_trial_bits
+ rsqrt_term1_shift = self.current_shift + 1
+ rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
+ rsqrt_term2 = dr_times_trial_bits_sqrd
+ rsqrt_term2_shift = self.current_shift * 2
+ rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
+
+ comb += self.trial_compare_rhs.eq(rsqrt_rhs)
return m
def __init__(self, core_config, stage_index):
""" Create a ``DivPipeCoreSetupStage`` instance. """
- self.core_config = core_config
assert stage_index in range(core_config.n_stages)
+ self.core_config = core_config
+ bw = core_config.bit_width
+ # TODO(programmerjake): re-enable once bit_width reduction is fixed
+ if False and core_config.supported == {DP.UDivRem}:
+ self.compare_len = bw * 2
+ else:
+ self.compare_len = bw * 3
self.stage_index = stage_index
self.i = self.ispec()
self.o = self.ospec()
""" Elaborate into ``Module``. """
m = Module()
comb = m.d.comb
+ cc = self.core_config
# copy invariant inputs to outputs (for next stage)
comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
comb += t.compare_rhs.eq(self.i.compare_rhs)
comb += t.operation.eq(self.i.operation)
- # get the trial output
+ # get the trial output (needed even in pass_flags[0] case)
trial_compare_rhs_values.append(t.trial_compare_rhs)
# make the trial comparison against the [invariant] lhs.
# trial_compare_rhs is always decreasing as trial_bits increases
pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
- comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
+ if trial_bits == 0:
+ # do not do first comparison: no point.
+ comb += pass_flag.eq(1)
+ else:
+ comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
pfl.append(pass_flag)
# Cat all the pass flags list together (easier to handle, below)
with m.Else():
comb += next_bits.eq(radix-1)
- # get the highest passing rhs trial (indexed by next_bits)
- ta = Array(trial_compare_rhs_values)
- comb += self.o.compare_rhs.eq(ta[next_bits])
+ # get the highest passing rhs trial. use treereduce because
+ # Array on such massively long numbers is insanely gate-hungry
+ crhs = []
+ tcrh = trial_compare_rhs_values
+ for i in range(radix):
+ nbe = Signal(reset_less=True)
+ comb += nbe.eq(next_bits == i)
+ crhs.append(Repl(nbe, self.compare_len) & tcrh[i])
+ comb += self.o.compare_rhs.eq(treereduce(crhs, operator.or_,
+ lambda x:x))
# create outputs for next phase
qr = self.i.quotient_root | (next_bits << current_shift)
- rr = self.i.root_times_radicand + ((self.i.divisor_radicand * next_bits)
- << current_shift)
comb += self.o.quotient_root.eq(qr)
- comb += self.o.root_times_radicand.eq(rr)
+ if DP.RSqrtRem in cc.supported:
+ rr = self.i.root_times_radicand + ((self.i.divisor_radicand *
+ next_bits) << current_shift)
+ comb += self.o.root_times_radicand.eq(rr)
return m