"""
-from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable, Array, Value
+from nmigen import (Signal, Cat, Const, Mux, Module, Elaboratable, Array,
+ Value, Shape)
from math import log
from operator import or_
from functools import reduce
"""
return x & self.mantissa_mask
def get_mantissa_value(self, x):
    """ returns the mantissa of its input number, x, but with the
        implicit bit, if any, made explicit.

        When the format stores its integer bit (`has_int_bit`), the
        mantissa field is returned unchanged.  Otherwise the implicit
        leading bit is placed at bit position `fraction_width`: it is 1
        whenever the exponent field differs from the denormal/zero
        exponent encoding, and 0 for zero and denormal numbers.
    """
    mantissa_field = self.get_mantissa_field(x)
    if self.has_int_bit:
        # the leading bit is already part of the stored mantissa field
        return mantissa_field
    exponent_field = self.get_exponent_field(x)
    # normal numbers (exponent field != denormal/zero encoding) have an
    # implicit leading 1; zero and denormals have an implicit leading 0
    implicit_bit = exponent_field != self.exponent_denormal_zero
    return (implicit_bit << self.fraction_width) | mantissa_field
+
def is_zero(self, x):
""" returns true if x is +/- zero
"""
(self.get_mantissa_field(x) != 0) & \
(self.get_mantissa_field(x) & highbit != 0)
def to_quiet_nan(self, x):
    """ converts `x` to a quiet NaN

        ORs the most significant mantissa bit (bit `m_width - 1`) and
        every bit of `exponent_mask` into `x`; all other bits of `x`
        are passed through unchanged.
    """
    quiet_bit = 1 << (self.m_width - 1)
    with_quiet_bit = x | quiet_bit
    return with_quiet_bit | self.exponent_mask
+
def quiet_nan(self, sign=0):
    """ return the default quiet NaN with sign `sign`

        built by turning a zero of the requested sign into a quiet NaN.
    """
    signed_zero = self.zero(sign)
    return self.to_quiet_nan(signed_zero)
+
def zero(self, sign=0):
    """ return zero with sign `sign`

        any non-zero `sign` sets the bit at position
        `e_width + m_width`; every other bit is clear.
    """
    is_negative = sign != 0
    sign_bit_position = self.e_width + self.m_width
    return is_negative << sign_bit_position
+
def inf(self, sign=0):
    """ return infinity with sign `sign`

        a zero of the requested sign with every bit of `exponent_mask`
        set on top of it.
    """
    signed_zero = self.zero(sign)
    return signed_zero | self.exponent_mask
+
def is_nan_signaling(self, x):
""" returns true if x is a signalling nan
"""
""" Get a mantissa mask based on the mantissa width """
return (1 << self.m_width) - 1
@property
def exponent_mask(self):
    """ Get an exponent mask: the infinity/NaN exponent value moved up
        past the `m_width` mantissa bits.
    """
    inf_nan_exponent = self.exponent_inf_nan
    return inf_nan_exponent << self.m_width
+
@property
def exponent_inf_nan(self):
""" Get the value of the exponent field designating infinity/NaN. """
def __init__(self, width, s_max=None):
if s_max is None:
s_max = int(log(width) / log(2))
- self.smax = s_max
+ self.smax = Shape.cast(s_max)
self.m = Signal(width, reset_less=True)
self.inp = Signal(width, reset_less=True)
self.diff = Signal(s_max, reset_less=True)
smask = Signal(self.width, reset_less=True)
stickybit = Signal(reset_less=True)
# XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
- maxslen = Signal(self.smax[0], reset_less=True)
- maxsleni = Signal(self.smax[0], reset_less=True)
+ maxslen = Signal(self.smax.width, reset_less=True)
+ maxsleni = Signal(self.smax.width, reset_less=True)
sm = MultiShift(self.width-1)
m0s = Const(0, self.width-1)
--- /dev/null
+""" floating-point fused-multiply-add
+
+computes `z = (a * c) + b` but only rounds once at the end
+"""
+
+from nmutil.pipemodbase import PipeModBase
+from ieee754.fpcommon.fpbase import FPRoundingMode
+from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormOutData
+from nmigen.hdl.dsl import Module
+from nmigen.hdl.ast import Signal, signed, unsigned, Mux
+from ieee754.fpfma.util import expanded_exponent_shape, \
+ expanded_mantissa_shape, get_fpformat
+from ieee754.fpcommon.getop import FPPipeContext
+
+
class FPFMAPostCalcData:
    """Output data of the FMA main calculation stage.

    Carries the sign, unbiased exponent and un-normalized mantissa of the
    intermediate result, plus the bypass value/flag, pipe context and
    rounding mode.  Provides `__iter__` and `eq` in the same way as
    `FPFMASpecialCasesDeNormOutData`, so both stage-data classes can be
    enumerated and assigned field by field.
    """

    def __init__(self, pspec):
        fpf = get_fpformat(pspec)

        # sign
        self.sign = Signal()

        # exponent -- unbiased
        self.exponent = Signal(expanded_exponent_shape(fpf))

        # unnormalized mantissa
        self.mantissa = Signal(expanded_mantissa_shape(fpf))

        # final output value of the fma when `do_bypass` is set
        self.bypassed_z = Signal(fpf.width)

        # set if `bypassed_z` is the final output value of the fma
        self.do_bypass = Signal()

        # pipe context
        self.ctx = FPPipeContext(pspec)

        # rounding mode
        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)

    def __iter__(self):
        """Yield every signal of this data object, context included."""
        yield self.sign
        yield self.exponent
        yield self.mantissa
        yield self.bypassed_z
        yield self.do_bypass
        yield from self.ctx
        yield self.rm

    def eq(self, i):
        """Return the assignments copying every field of `i` into self."""
        return [
            self.sign.eq(i.sign),
            self.exponent.eq(i.exponent),
            self.mantissa.eq(i.mantissa),
            self.bypassed_z.eq(i.bypassed_z),
            self.do_bypass.eq(i.do_bypass),
            self.ctx.eq(i.ctx),
            self.rm.eq(i.rm),
        ]
+
+
class FPFMAMainStage(PipeModBase):
    """Main calculation stage of the fused multiply-add.

    Multiplies the `a` and `c` mantissas, adds (or subtracts, when
    `do_sub` is set) the `b` mantissa, and outputs the magnitude of the
    result together with its sign.  A sum of exactly zero that was not
    already bypassed is turned into a bypassed zero whose sign is looked
    up from the rounding mode.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "main")

    def ispec(self):
        """Input data: output of the special-cases/de-normalize stage."""
        return FPFMASpecialCasesDeNormOutData(self.pspec)

    def ospec(self):
        """Output data: un-normalized result of the calculation."""
        return FPFMAPostCalcData(self.pspec)

    def elaborate(self, platform):
        m = Module()
        fpf = get_fpformat(self.pspec)
        assert fpf.has_sign
        stage_in = self.i
        stage_out = self.o

        # NOTE: locals bound directly to Signal() keep their original
        # names, since nmigen derives the signal name from the variable.
        raw_product = stage_in.a_mantissa * stage_in.c_mantissa
        product = Signal(raw_product.shape())
        m.d.comb += product.eq(raw_product)

        # do_sub replicated as a signed (sign-extends to all-ones for
        # the xor) and as an unsigned (the +1) one-bit value, i.e. a
        # two's-complement negation of b_mantissa when subtracting
        negate_b_s = Signal(signed(1))
        negate_b_u = Signal(unsigned(1))
        m.d.comb += negate_b_s.eq(stage_in.do_sub)
        m.d.comb += negate_b_u.eq(stage_in.do_sub)

        raw_sum = raw_product + (stage_in.b_mantissa ^ negate_b_s) + negate_b_u
        sum = Signal(raw_sum.shape())
        m.d.comb += sum.eq(raw_sum)

        sum_neg = Signal()
        sum_zero = Signal()
        m.d.comb += sum_neg.eq(sum < 0)  # just sign bit
        m.d.comb += sum_zero.eq(sum == 0)

        zero_signs = FPRoundingMode.make_array(FPRoundingMode.zero_sign)

        result_is_new_zero = sum_zero & ~stage_in.do_bypass
        with m.If(result_is_new_zero):
            # exact zero result: bypass with a rounding-mode chosen sign
            m.d.comb += stage_out.bypassed_z.eq(
                fpf.zero(zero_signs[stage_in.rm]))
            m.d.comb += stage_out.do_bypass.eq(True)
        with m.Else():
            # otherwise forward whatever the previous stage decided
            m.d.comb += stage_out.bypassed_z.eq(stage_in.bypassed_z)
            m.d.comb += stage_out.do_bypass.eq(stage_in.do_bypass)

        magnitude = Mux(sum_neg, -sum, sum)
        m.d.comb += stage_out.sign.eq(sum_neg ^ stage_in.sign)
        m.d.comb += stage_out.exponent.eq(stage_in.exponent)
        m.d.comb += stage_out.mantissa.eq(magnitude)
        m.d.comb += stage_out.ctx.eq(stage_in.ctx)
        m.d.comb += stage_out.rm.eq(stage_in.rm)
        return m
--- /dev/null
+from nmutil.pipemodbase import PipeModBaseChain, PipeModBase
+from ieee754.fpcommon.postnormalise import FPNorm1Data
+from ieee754.fpcommon.roundz import FPRoundMod
+from ieee754.fpcommon.corrections import FPCorrectionsMod
+from ieee754.fpcommon.pack import FPPackMod
+from ieee754.fpfma.main_stage import FPFMAPostCalcData
+from nmigen.hdl.dsl import Module
+
+from ieee754.fpfma.util import get_fpformat
+
+
class FPFMANorm(PipeModBase):
    """Normalization stage of the fused multiply-add (not implemented yet).

    Takes `FPFMAPostCalcData` and is meant to produce `FPNorm1Data` for
    the rounding/corrections/pack stages that follow it.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "norm")

    def ispec(self):
        """Input data: output of the main calculation stage."""
        return FPFMAPostCalcData(self.pspec)

    def ospec(self):
        """Output data: input format of the rounding stage."""
        return FPNorm1Data(self.pspec)

    def elaborate(self, platform):
        """Always raises NotImplementedError; the stage is unfinished.

        Raises:
            NotImplementedError: unconditionally (after the format check).
        """
        fpf = get_fpformat(self.pspec)
        assert fpf.has_sign
        # FIXME: finish -- the finished stage has to drive, from self.i,
        # each of these fields of self.o:
        #   roundz, z, out_do_z, oz, ctx, rm
        raise NotImplementedError(
            "FPFMANorm.elaborate is not implemented yet")
+
+
class FPFMANormToPack(PipeModBaseChain):
    """Chain of the FMA normalize, round, corrections and pack stages."""

    def __init__(self, pspec):
        super().__init__(pspec)

    def get_chain(self):
        """ gets chain of modules: one instance of each stage, in
            pipeline order, all built from the same pspec
        """
        stage_classes = (FPFMANorm, FPRoundMod, FPCorrectionsMod, FPPackMod)
        return [stage_class(self.pspec) for stage_class in stage_classes]
--- /dev/null
+""" floating-point fused-multiply-add
+
+computes `z = (a * c) + b` but only rounds once at the end
+"""
+
+from nmutil.singlepipe import ControlBase
+from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNorm
+from ieee754.fpfma.main_stage import FPFMAMainStage
+from ieee754.fpfma.norm import FPFMANormToPack
+
+
class FPFMABasePipe(ControlBase):
    """FMA pipeline: special-cases/de-normalize, main stage, then the
    normalize-to-pack chain, connected in that order.
    """

    def __init__(self, pspec):
        super().__init__()
        self.sc_denorm = FPFMASpecialCasesDeNorm(pspec)
        self.main = FPFMAMainStage(pspec)
        self.normpack = FPFMANormToPack(pspec)
        stages = [self.sc_denorm, self.main, self.normpack]
        # connection statements are kept and added to the module on elaborate
        self._eqs = self.connect(stages)

    def elaborate(self, platform):
        m = super().elaborate(platform)
        # register every stage as a submodule under its attribute name
        for stage_name in ("sc_denorm", "main", "normpack"):
            setattr(m.submodules, stage_name, getattr(self, stage_name))
        m.d.comb += self._eqs
        return m
--- /dev/null
+""" floating-point fused-multiply-add
+
+computes `z = (a * c) + b` but only rounds once at the end
+"""
+
+from nmutil.pipemodbase import PipeModBase
+from ieee754.fpcommon.basedata import FPBaseData
+from nmigen.hdl.ast import Signal
+from nmigen.hdl.dsl import Module
+from ieee754.fpcommon.getop import FPPipeContext
+from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge
+from ieee754.fpfma.util import expanded_exponent_shape, \
+ expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape
+
+
class FPFMAInputData(FPBaseData):
    """Input data of the FMA pipeline: the three-operand base data plus
    two flags selecting negation of the addend and of the product.
    """

    def __init__(self, pspec):
        assert pspec.n_ops == 3
        super().__init__(pspec)

        # if the addend should be negated
        self.negate_addend = Signal()

        # if the product should be negated
        self.negate_product = Signal()

    def eq(self, i):
        """Base-class assignments followed by those of the two flags."""
        assignments = super().eq(i)
        assignments.extend((
            self.negate_addend.eq(i.negate_addend),
            self.negate_product.eq(i.negate_product),
        ))
        return assignments

    def __iter__(self):
        """Yield the base-class signals, then the two negate flags."""
        yield from super().__iter__()
        yield from (self.negate_addend, self.negate_product)

    def ports(self):
        """Return all signals of this data object as a list."""
        return list(self)
+
+
class FPFMASpecialCasesDeNormOutData:
    """Output data of the FMA special-cases/de-normalize stage.

    `a_mantissa` and `c_mantissa` are the two multiplicands and use the
    multiplicand shape.  `b_mantissa` is the addend after it has been
    shifted into position for the add, so it uses the wider expanded
    mantissa shape.
    """

    def __init__(self, pspec):
        fpf = get_fpformat(pspec)

        # sign
        self.sign = Signal()

        # exponent of intermediate -- unbiased
        self.exponent = Signal(expanded_exponent_shape(fpf))

        # mantissa of a input -- un-normalized and with implicit bit added
        self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf))

        # mantissa of b input -- shifted to appropriate location for add
        # and with implicit bit added, hence the expanded shape
        self.b_mantissa = Signal(expanded_mantissa_shape(fpf))

        # mantissa of c input -- un-normalized and with implicit bit added
        self.c_mantissa = Signal(multiplicand_mantissa_shape(fpf))

        # true if `b_mantissa` should be subtracted from
        # `a_mantissa * c_mantissa` rather than added
        self.do_sub = Signal()

        # final output value of the fma when `do_bypass` is set
        self.bypassed_z = Signal(fpf.width)

        # set if `bypassed_z` is the final output value of the fma
        self.do_bypass = Signal()

        # pipe context
        self.ctx = FPPipeContext(pspec)

        # rounding mode
        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)

    def __iter__(self):
        """Yield every signal of this data object, context included."""
        yield self.sign
        yield self.exponent
        yield self.a_mantissa
        yield self.b_mantissa
        yield self.c_mantissa
        yield self.do_sub
        yield self.bypassed_z
        yield self.do_bypass
        yield from self.ctx
        yield self.rm

    def eq(self, i):
        """Return the assignments copying every field of `i` into self."""
        return [
            self.sign.eq(i.sign),
            self.exponent.eq(i.exponent),
            self.a_mantissa.eq(i.a_mantissa),
            self.b_mantissa.eq(i.b_mantissa),
            self.c_mantissa.eq(i.c_mantissa),
            self.do_sub.eq(i.do_sub),
            self.bypassed_z.eq(i.bypassed_z),
            self.do_bypass.eq(i.do_bypass),
            self.ctx.eq(i.ctx),
            self.rm.eq(i.rm),
        ]
+
+
class FPFMASpecialCasesDeNorm(PipeModBase):
    """First FMA stage: special-case detection and addend alignment.

    Computes the sign of the product and of the addend, picks the larger
    of the product exponent and the addend exponent as the working
    exponent, right-shifts the addend mantissa accordingly, and sets
    `do_bypass`/`bypassed_z` for NaN, infinity and zero-plus-zero inputs.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "sc_denorm")

    def ispec(self):
        """Input data: the three operands plus the negate flags."""
        return FPFMAInputData(self.pspec)

    def ospec(self):
        """Output data: aligned mantissas, exponent and bypass info."""
        return FPFMASpecialCasesDeNormOutData(self.pspec)

    def elaborate(self, platform):
        m = Module()
        fpf = get_fpformat(self.pspec)
        assert fpf.has_sign
        inp = self.i
        out = self.o

        # exponents of the inputs, of the product, and their difference
        a_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a))
        b_exponent_in = Signal(expanded_exponent_shape(fpf))
        m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b))
        c_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c))
        prod_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += prod_exponent.eq(a_exponent + c_exponent)
        prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf))
        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in)
        b_mantissa_in = Signal(fpf.fraction_width + 1)
        m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b))

        # effective signs after applying the negate flags
        p_sign = Signal()
        m.d.comb += p_sign.eq(fpf.get_sign_field(inp.a) ^
                              fpf.get_sign_field(inp.c) ^ inp.negate_product)
        b_sign = Signal()
        m.d.comb += b_sign.eq(fpf.get_sign_field(inp.b) ^ inp.negate_addend)

        exponent = Signal(expanded_exponent_shape(fpf))
        b_shift = Signal(expanded_exponent_shape(fpf))
        # use >= since that's just checking the sign bit
        with m.If(prod_exp_minus_b_exp >= 0):
            m.d.comb += [
                exponent.eq(prod_exponent),
                b_shift.eq(prod_exp_minus_b_exp),
            ]
        with m.Else():
            m.d.comb += [
                exponent.eq(b_exponent_in),
                b_shift.eq(0),
            ]

        # place b's mantissa at the top of the shifter input, then shift
        # it right by b_shift
        m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width)
        m.d.comb += [
            rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width
                                             - b_mantissa_in.width)),
            rshiftm.diff.eq(b_shift),
        ]

        # handle special cases
        with m.If(fpf.is_nan(inp.a)):
            m.d.comb += [
                out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_nan(inp.b)):
            m.d.comb += [
                out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_nan(inp.c)):
            m.d.comb += [
                out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)),
                out.do_bypass.eq(True),
            ]
        with m.Elif((fpf.is_zero(inp.a) & fpf.is_inf(inp.c))
                    | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))):
            # infinity * 0
            m.d.comb += [
                out.bypassed_z.eq(fpf.quiet_nan()),
                out.do_bypass.eq(True),
            ]
        # the sign comparison must be parenthesized: `&` binds tighter
        # than `!=`, so without parentheses the whole `&` chain would be
        # compared against b_sign
        with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c))
                    & fpf.is_inf(inp.b) & (p_sign != b_sign)):
            # inf - inf
            m.d.comb += [
                out.bypassed_z.eq(fpf.quiet_nan()),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)):
            # inf + x
            m.d.comb += [
                out.bypassed_z.eq(fpf.inf(p_sign)),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_inf(inp.b)):
            # x + inf
            m.d.comb += [
                out.bypassed_z.eq(fpf.inf(b_sign)),
                out.do_bypass.eq(True),
            ]
        # parenthesized for the same precedence reason as above
        with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
                    & fpf.is_zero(inp.b) & (p_sign == b_sign)):
            # zero + zero
            m.d.comb += [
                out.bypassed_z.eq(fpf.zero(p_sign)),
                out.do_bypass.eq(True),
            ]
            # zero - zero handled by FPFMAMainStage
        with m.Else():
            m.d.comb += [
                out.bypassed_z.eq(0),
                out.do_bypass.eq(False),
            ]

        m.d.comb += [
            out.sign.eq(p_sign),
            out.exponent.eq(exponent),
            out.a_mantissa.eq(fpf.get_mantissa_value(inp.a)),
            out.b_mantissa.eq(rshiftm.m),
            out.c_mantissa.eq(fpf.get_mantissa_value(inp.c)),
            out.do_sub.eq(p_sign != b_sign),
            out.ctx.eq(inp.ctx),
            out.rm.eq(inp.rm),
        ]

        return m
--- /dev/null
+from ieee754.fpcommon.fpbase import FPFormat
+from nmigen.hdl.ast import signed, unsigned
+
+
def expanded_exponent_shape(fpformat):
    """Return a signed shape three bits wider than the exponent field."""
    assert isinstance(fpformat, FPFormat)
    exponent_bits = fpformat.e_width + 3
    return signed(exponent_bits)
+
+
# number of extra least-significant bits in the expanded mantissa
EXPANDED_MANTISSA_EXTRA_LSBS = 3


def expanded_mantissa_shape(fpformat):
    """Return the signed shape of the expanded mantissa.

    The width is three times the fraction width, plus 2, plus
    `EXPANDED_MANTISSA_EXTRA_LSBS`.
    """
    assert isinstance(fpformat, FPFormat)
    mantissa_bits = fpformat.fraction_width * 3
    mantissa_bits = mantissa_bits + 2 + EXPANDED_MANTISSA_EXTRA_LSBS
    return signed(mantissa_bits)
+
+
def multiplicand_mantissa_shape(fpformat):
    """Return an unsigned shape one bit wider than the fraction field."""
    assert isinstance(fpformat, FPFormat)
    mantissa_bits = fpformat.fraction_width + 1
    return unsigned(mantissa_bits)
+
+
def product_mantissa_shape(fpformat):
    """Return an unsigned shape twice as wide as a multiplicand mantissa."""
    assert isinstance(fpformat, FPFormat)
    multiplicand = multiplicand_mantissa_shape(fpformat)
    return unsigned(multiplicand.width * 2)
+
+
def get_fpformat(pspec):
    """Return the FPFormat for `pspec`.

    Uses `pspec.fpformat` when that attribute exists and is not None;
    otherwise falls back to the standard format for `pspec.width`.
    Asserts that `pspec.width` is an int and equals the format's width.
    """
    width = pspec.width
    assert isinstance(width, int)
    fpformat = getattr(pspec, "fpformat", None)
    if fpformat is not None:
        assert isinstance(fpformat, FPFormat)
    else:
        fpformat = FPFormat.standard(width)
    assert width == fpformat.width
    return fpformat