computes `z = (a * c) + b` but only rounds once at the end
"""
-from nmutil.pipemodbase import PipeModBase
+from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
from ieee754.fpcommon.basedata import FPBaseData
from nmigen.hdl.ast import Signal
from nmigen.hdl.dsl import Module
from ieee754.fpcommon.getop import FPPipeContext
from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge
from ieee754.fpfma.util import expanded_exponent_shape, \
- expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape
+ expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape, \
+ EXPANDED_MANTISSA_EXTRA_MSBS, EXPANDED_MANTISSA_EXTRA_LSBS, \
+ product_mantissa_shape
class FPFMAInputData(FPBaseData):
self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf))
"""mantissa of a input -- un-normalized and with implicit bit added"""
- self.b_mantissa = Signal(multiplicand_mantissa_shape(fpf))
+ self.b_mantissa = Signal(expanded_mantissa_shape(fpf))
"""mantissa of b input
shifted to appropriate location for add and with implicit bit added
"""
- self.c_mantissa = Signal(expanded_mantissa_shape(fpf))
+ self.c_mantissa = Signal(multiplicand_mantissa_shape(fpf))
"""mantissa of c input -- un-normalized and with implicit bit added"""
self.do_sub = Signal()
out = self.o
a_exponent = Signal(expanded_exponent_shape(fpf))
- m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a))
+ m.d.comb += a_exponent.eq(fpf.get_exponent_value(inp.a))
b_exponent_in = Signal(expanded_exponent_shape(fpf))
- m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b))
+ m.d.comb += b_exponent_in.eq(fpf.get_exponent_value(inp.b))
c_exponent = Signal(expanded_exponent_shape(fpf))
- m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c))
+ m.d.comb += c_exponent.eq(fpf.get_exponent_value(inp.c))
+ b_exponent = Signal(expanded_exponent_shape(fpf))
+ m.d.comb += b_exponent.eq(b_exponent_in + EXPANDED_MANTISSA_EXTRA_MSBS)
prod_exponent = Signal(expanded_exponent_shape(fpf))
- m.d.comb += prod_exponent.eq(a_exponent + c_exponent)
+
+ # number of bits that the product of two normalized signals needs to
+ # be shifted left to be normalized, e.g. the product of 2 8-bit
+ # numbers `0x80 * 0x80 == 0x4000` and `0x4000` needs to be shifted
+ # left by `PROD_STAY_NORM_SHIFT` bits to be normalized again:
+ # `0x4000 << 1 == 0x8000`
+ PROD_STAY_NORM_SHIFT = 1
+
+ extra_prod_exponent = (expanded_mantissa_shape(fpf).width
+ - product_mantissa_shape(fpf).width
+ + PROD_STAY_NORM_SHIFT
+ - EXPANDED_MANTISSA_EXTRA_LSBS)
+ m.d.comb += prod_exponent.eq(a_exponent + c_exponent
+ + extra_prod_exponent)
prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf))
- m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in)
+ m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent)
b_mantissa_in = Signal(fpf.fraction_width + 1)
m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b))
p_sign = Signal()
]
with m.Else():
m.d.comb += [
- exponent.eq(b_exponent_in),
+ exponent.eq(b_exponent),
b_shift.eq(0),
]
- m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width)
+ m.submodules.rshiftm = rshiftm = MultiShiftRMerge(
+ out.b_mantissa.width - EXPANDED_MANTISSA_EXTRA_MSBS,
+ s_max=expanded_exponent_shape(fpf).width - 1)
m.d.comb += [
- rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width
- - b_mantissa_in.width)),
+ rshiftm.inp.eq(0),
+ rshiftm.inp[-b_mantissa_in.width:].eq(b_mantissa_in),
rshiftm.diff.eq(b_shift),
]
+ keep = {"keep": True}
+
# handle special cases
with m.If(fpf.is_nan(inp.a)):
m.d.comb += [
+ Signal(name="case_nan_a", attrs=keep).eq(True),
out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)),
out.do_bypass.eq(True),
]
with m.Elif(fpf.is_nan(inp.b)):
m.d.comb += [
+ Signal(name="case_nan_b", attrs=keep).eq(True),
out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)),
out.do_bypass.eq(True),
]
with m.Elif(fpf.is_nan(inp.c)):
m.d.comb += [
+ Signal(name="case_nan_c", attrs=keep).eq(True),
out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)),
out.do_bypass.eq(True),
]
| (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))):
# infinity * 0
m.d.comb += [
+ Signal(name="case_inf_times_zero", attrs=keep).eq(True),
out.bypassed_z.eq(fpf.quiet_nan()),
out.do_bypass.eq(True),
]
with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c))
- & fpf.is_inf(inp.b) & p_sign != b_sign):
+ & fpf.is_inf(inp.b) & (p_sign != b_sign)):
# inf - inf
m.d.comb += [
+ Signal(name="case_inf_minus_inf", attrs=keep).eq(True),
out.bypassed_z.eq(fpf.quiet_nan()),
out.do_bypass.eq(True),
]
with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)):
# inf + x
m.d.comb += [
+ Signal(name="case_inf_plus_x", attrs=keep).eq(True),
out.bypassed_z.eq(fpf.inf(p_sign)),
out.do_bypass.eq(True),
]
with m.Elif(fpf.is_inf(inp.b)):
# x + inf
m.d.comb += [
+ Signal(name="case_x_plus_inf", attrs=keep).eq(True),
out.bypassed_z.eq(fpf.inf(b_sign)),
out.do_bypass.eq(True),
]
with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
- & fpf.is_zero(inp.b) & p_sign == b_sign):
+ & fpf.is_zero(inp.b) & (p_sign == b_sign)):
# zero + zero
m.d.comb += [
+ Signal(name="case_zero_plus_zero", attrs=keep).eq(True),
out.bypassed_z.eq(fpf.zero(p_sign)),
out.do_bypass.eq(True),
]
- # zero - zero handled by FPFMAMainStage
+ with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
+ & ~fpf.is_zero(inp.b)):
+ # zero + x
+ m.d.comb += [
+ Signal(name="case_zero_plus_x", attrs=keep).eq(True),
+ out.bypassed_z.eq(inp.b),
+ out.do_bypass.eq(True),
+ ]
with m.Else():
+ # zero - zero handled by FPFMAMainStage
m.d.comb += [
out.bypassed_z.eq(0),
out.do_bypass.eq(False),
]
return m
+
+
+class FPFMASpecialCasesDeNormStage(PipeModBaseChain):
+ def __init__(self, pspec):
+ super().__init__(pspec)
+
+ def get_chain(self):
+ """ gets chain of modules
+ """
+ return [FPFMASpecialCasesDeNorm(self.pspec)]