X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fieee754%2Ffpfma%2Fspecial_cases.py;h=826c32a8e80f2a121a585d7c83bb65e436e97d50;hb=449176c8896dd13ae80130a3a0c8fc88026a2499;hp=95d3026692465afeba1bf81955b87c3123b9e779;hpb=da903f9bc46de51cae1f4a8fd9e9ea2d83bd572a;p=ieee754fpu.git diff --git a/src/ieee754/fpfma/special_cases.py b/src/ieee754/fpfma/special_cases.py index 95d30266..826c32a8 100644 --- a/src/ieee754/fpfma/special_cases.py +++ b/src/ieee754/fpfma/special_cases.py @@ -3,14 +3,16 @@ computes `z = (a * c) + b` but only rounds once at the end """ -from nmutil.pipemodbase import PipeModBase +from nmutil.pipemodbase import PipeModBase, PipeModBaseChain from ieee754.fpcommon.basedata import FPBaseData from nmigen.hdl.ast import Signal from nmigen.hdl.dsl import Module from ieee754.fpcommon.getop import FPPipeContext from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge from ieee754.fpfma.util import expanded_exponent_shape, \ - expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape + expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape, \ + EXPANDED_MANTISSA_EXTRA_MSBS, EXPANDED_MANTISSA_EXTRA_LSBS, \ + product_mantissa_shape class FPFMAInputData(FPBaseData): @@ -52,13 +54,13 @@ class FPFMASpecialCasesDeNormOutData: self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf)) """mantissa of a input -- un-normalized and with implicit bit added""" - self.b_mantissa = Signal(multiplicand_mantissa_shape(fpf)) + self.b_mantissa = Signal(expanded_mantissa_shape(fpf)) """mantissa of b input shifted to appropriate location for add and with implicit bit added """ - self.c_mantissa = Signal(expanded_mantissa_shape(fpf)) + self.c_mantissa = Signal(multiplicand_mantissa_shape(fpf)) """mantissa of c input -- un-normalized and with implicit bit added""" self.do_sub = Signal() @@ -123,15 +125,30 @@ class FPFMASpecialCasesDeNorm(PipeModBase): out = self.o a_exponent = Signal(expanded_exponent_shape(fpf)) - m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a)) + m.d.comb += a_exponent.eq(fpf.get_exponent_value(inp.a)) b_exponent_in = Signal(expanded_exponent_shape(fpf)) - m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b)) + m.d.comb += b_exponent_in.eq(fpf.get_exponent_value(inp.b)) c_exponent = Signal(expanded_exponent_shape(fpf)) - m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c)) + m.d.comb += c_exponent.eq(fpf.get_exponent_value(inp.c)) + b_exponent = Signal(expanded_exponent_shape(fpf)) + m.d.comb += b_exponent.eq(b_exponent_in + EXPANDED_MANTISSA_EXTRA_MSBS) prod_exponent = Signal(expanded_exponent_shape(fpf)) - m.d.comb += prod_exponent.eq(a_exponent + c_exponent) + + # number of bits that the product of two normalized signals needs to + # be shifted left to be normalized, e.g. the product of 2 8-bit + # numbers `0x80 * 0x80 == 0x4000` and `0x4000` needs to be shifted + # left by `PROD_STAY_NORM_SHIFT` bits to be normalized again: + # `0x4000 << 1 == 0x8000` + PROD_STAY_NORM_SHIFT = 1 + + extra_prod_exponent = (expanded_mantissa_shape(fpf).width + - product_mantissa_shape(fpf).width + + PROD_STAY_NORM_SHIFT + - EXPANDED_MANTISSA_EXTRA_LSBS) + m.d.comb += prod_exponent.eq(a_exponent + c_exponent + + extra_prod_exponent) prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf)) - m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in) + m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent) b_mantissa_in = Signal(fpf.fraction_width + 1) m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b)) p_sign = Signal() @@ -150,30 +167,37 @@ class FPFMASpecialCasesDeNorm(PipeModBase): ] with m.Else(): m.d.comb += [ - exponent.eq(b_exponent_in), + exponent.eq(b_exponent), b_shift.eq(0), ] - m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width) + m.submodules.rshiftm = rshiftm = MultiShiftRMerge( + out.b_mantissa.width - EXPANDED_MANTISSA_EXTRA_MSBS, + s_max=expanded_exponent_shape(fpf).width - 1) m.d.comb += [ - rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width - - b_mantissa_in.width)), + rshiftm.inp.eq(0), + rshiftm.inp[-b_mantissa_in.width:].eq(b_mantissa_in), rshiftm.diff.eq(b_shift), ] + keep = {"keep": True} + # handle special cases with m.If(fpf.is_nan(inp.a)): m.d.comb += [ + Signal(name="case_nan_a", attrs=keep).eq(True), out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)), out.do_bypass.eq(True), ] with m.Elif(fpf.is_nan(inp.b)): m.d.comb += [ + Signal(name="case_nan_b", attrs=keep).eq(True), out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)), out.do_bypass.eq(True), ] with m.Elif(fpf.is_nan(inp.c)): m.d.comb += [ + Signal(name="case_nan_c", attrs=keep).eq(True), out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)), out.do_bypass.eq(True), ] @@ -181,37 +205,50 @@ class FPFMASpecialCasesDeNorm(PipeModBase): | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))): # infinity * 0 m.d.comb += [ + Signal(name="case_inf_times_zero", attrs=keep).eq(True), out.bypassed_z.eq(fpf.quiet_nan()), out.do_bypass.eq(True), ] with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c)) - & fpf.is_inf(inp.b) & p_sign != b_sign): + & fpf.is_inf(inp.b) & (p_sign != b_sign)): # inf - inf m.d.comb += [ + Signal(name="case_inf_minus_inf", attrs=keep).eq(True), out.bypassed_z.eq(fpf.quiet_nan()), out.do_bypass.eq(True), ] with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)): # inf + x m.d.comb += [ + Signal(name="case_inf_plus_x", attrs=keep).eq(True), out.bypassed_z.eq(fpf.inf(p_sign)), out.do_bypass.eq(True), ] with m.Elif(fpf.is_inf(inp.b)): # x + inf m.d.comb += [ + Signal(name="case_x_plus_inf", attrs=keep).eq(True), out.bypassed_z.eq(fpf.inf(b_sign)), out.do_bypass.eq(True), ] with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c)) - & fpf.is_zero(inp.b) & p_sign == b_sign): + & fpf.is_zero(inp.b) & (p_sign == b_sign)): # zero + zero m.d.comb += [ + Signal(name="case_zero_plus_zero", attrs=keep).eq(True), out.bypassed_z.eq(fpf.zero(p_sign)), out.do_bypass.eq(True), ] - # zero - zero handled by FPFMAMainStage + with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c)) + & ~fpf.is_zero(inp.b)): + # zero + x + m.d.comb += [ + Signal(name="case_zero_plus_x", attrs=keep).eq(True), + out.bypassed_z.eq(inp.b), + out.do_bypass.eq(True), + ] with m.Else(): + # zero - zero handled by FPFMAMainStage m.d.comb += [ out.bypassed_z.eq(0), out.do_bypass.eq(False), @@ -229,3 +266,13 @@ class FPFMASpecialCasesDeNorm(PipeModBase): ] return m + + +class FPFMASpecialCasesDeNormStage(PipeModBaseChain): + def __init__(self, pspec): + super().__init__(pspec) + + def get_chain(self): + """ gets chain of modules + """ + return [FPFMASpecialCasesDeNorm(self.pspec)]