From 449176c8896dd13ae80130a3a0c8fc88026a2499 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 4 Jul 2022 03:49:45 -0700 Subject: [PATCH] working on implementing fma, f16 rtz formal proof seems likely to work bitwuzla has been running the formal proof for the last 23min, seems like it'll probably succeed after a bunch of time. --- src/ieee754/fpfma/main_stage.py | 46 +- src/ieee754/fpfma/norm.py | 44 +- src/ieee754/fpfma/pipeline.py | 4 +- src/ieee754/fpfma/special_cases.py | 79 ++- src/ieee754/fpfma/test/__init__.py | 0 src/ieee754/fpfma/test/test_fma_formal.py | 559 ++++++++++++++++++++++ src/ieee754/fpfma/util.py | 41 +- 7 files changed, 736 insertions(+), 37 deletions(-) create mode 100644 src/ieee754/fpfma/test/__init__.py create mode 100644 src/ieee754/fpfma/test/test_fma_formal.py diff --git a/src/ieee754/fpfma/main_stage.py b/src/ieee754/fpfma/main_stage.py index 1ab2b2b8..7a028107 100644 --- a/src/ieee754/fpfma/main_stage.py +++ b/src/ieee754/fpfma/main_stage.py @@ -3,13 +3,13 @@ computes `z = (a * c) + b` but only rounds once at the end """ -from nmutil.pipemodbase import PipeModBase +from nmutil.pipemodbase import PipeModBase, PipeModBaseChain from ieee754.fpcommon.fpbase import FPRoundingMode from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormOutData from nmigen.hdl.dsl import Module -from nmigen.hdl.ast import Signal, signed, unsigned, Mux +from nmigen.hdl.ast import Signal, signed, unsigned, Mux, Cat from ieee754.fpfma.util import expanded_exponent_shape, \ - expanded_mantissa_shape, get_fpformat + expanded_mantissa_shape, get_fpformat, EXPANDED_MANTISSA_EXTRA_LSBS from ieee754.fpcommon.getop import FPPipeContext @@ -38,8 +38,31 @@ class FPFMAPostCalcData: self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT) """rounding mode""" + def eq(self, i): + return [ + self.sign.eq(i.sign), + self.exponent.eq(i.exponent), + self.mantissa.eq(i.mantissa), + self.bypassed_z.eq(i.bypassed_z), + self.do_bypass.eq(i.do_bypass), + self.ctx.eq(i.ctx), + self.rm.eq(i.rm), + ] + + def __iter__(self): + yield self.sign + yield self.exponent + yield self.mantissa + yield self.bypassed_z + yield self.do_bypass + yield self.ctx + yield self.rm + + def ports(self): + return list(self) + -class FPFMAMainStage(PipeModBase): +class FPFMAMain(PipeModBase): def __init__(self, pspec): super().__init__(pspec, "main") @@ -65,8 +88,9 @@ class FPFMAMainStage(PipeModBase): negate_b_s.eq(inp.do_sub), negate_b_u.eq(inp.do_sub), ] - sum_v = product_v + (inp.b_mantissa ^ negate_b_s) + negate_b_u - sum = Signal(sum_v.shape()) + sum_v = (product_v << EXPANDED_MANTISSA_EXTRA_LSBS) + \ + (inp.b_mantissa ^ negate_b_s) + negate_b_u + sum = Signal(expanded_mantissa_shape(fpf)) m.d.comb += sum.eq(sum_v) sum_neg = Signal() @@ -97,3 +121,13 @@ class FPFMAMainStage(PipeModBase): out.rm.eq(inp.rm), ] return m + + +class FPFMAMainStage(PipeModBaseChain): + def __init__(self, pspec): + super().__init__(pspec) + + def get_chain(self): + """ gets chain of modules + """ + return [FPFMAMain(self.pspec)] diff --git a/src/ieee754/fpfma/norm.py b/src/ieee754/fpfma/norm.py index 0c16f452..21022c81 100644 --- a/src/ieee754/fpfma/norm.py +++ b/src/ieee754/fpfma/norm.py @@ -1,12 +1,14 @@ from nmutil.pipemodbase import PipeModBaseChain, PipeModBase +from ieee754.fpcommon.fpbase import OverflowMod from ieee754.fpcommon.postnormalise import FPNorm1Data from ieee754.fpcommon.roundz import FPRoundMod from ieee754.fpcommon.corrections import FPCorrectionsMod from ieee754.fpcommon.pack import FPPackMod from ieee754.fpfma.main_stage import FPFMAPostCalcData from nmigen.hdl.dsl import Module - +from nmigen.hdl.ast import Signal from ieee754.fpfma.util import get_fpformat +from nmigen.lib.coding import PriorityEncoder class FPFMANorm(PipeModBase): @@ -23,16 +25,38 @@ class FPFMANorm(PipeModBase): m = Module() fpf = get_fpformat(self.pspec) assert fpf.has_sign - inp = self.i - out = self.o - raise NotImplementedError # FIXME: finish + inp: FPFMAPostCalcData = self.i + out: FPNorm1Data = self.o + m.submodules.pri_enc = pri_enc = PriorityEncoder(inp.mantissa.width) + m.d.comb += pri_enc.i.eq(inp.mantissa[::-1]) + unrestricted_shift_amount = Signal(range(inp.mantissa.width)) + shift_amount = Signal(range(inp.mantissa.width)) + m.d.comb += unrestricted_shift_amount.eq(pri_enc.o) + with m.If(inp.exponent - (1 + fpf.e_sub) < unrestricted_shift_amount): + m.d.comb += shift_amount.eq(inp.exponent - (1 + fpf.e_sub)) + with m.Else(): + m.d.comb += shift_amount.eq(unrestricted_shift_amount) + n_mantissa = Signal(inp.mantissa.width) + m.d.comb += n_mantissa.eq(inp.mantissa << shift_amount) + + m.submodules.of = of = OverflowMod() m.d.comb += [ - out.roundz.eq(), - out.z.eq(), - out.out_do_z.eq(), - out.oz.eq(), - out.ctx.eq(), - out.rm.eq(), + pri_enc.i.eq(inp.mantissa[::-1]), + of.guard.eq(n_mantissa[-(out.z.m.width + 1)]), + of.round_bit.eq(n_mantissa[-(out.z.m.width + 2)]), + of.sticky.eq(n_mantissa[:-(out.z.m.width + 2)].bool()), + of.m0.eq(out.z.m[0]), + of.fpflags.eq(0), + of.sign.eq(inp.sign), + of.rm.eq(inp.rm), + out.roundz.eq(of.roundz_out), + out.z.s.eq(inp.sign), + out.z.e.eq(inp.exponent - shift_amount), + out.z.m.eq(n_mantissa[-out.z.m.width:]), + out.out_do_z.eq(inp.do_bypass), + out.oz.eq(inp.bypassed_z), + out.ctx.eq(inp.ctx), + out.rm.eq(inp.rm), ] return m diff --git a/src/ieee754/fpfma/pipeline.py b/src/ieee754/fpfma/pipeline.py index f0b928d0..3661d3c4 100644 --- a/src/ieee754/fpfma/pipeline.py +++ b/src/ieee754/fpfma/pipeline.py @@ -4,7 +4,7 @@ computes `z = (a * c) + b` but only rounds once at the end """ from nmutil.singlepipe import ControlBase -from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNorm +from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormStage from ieee754.fpfma.main_stage import FPFMAMainStage from ieee754.fpfma.norm import FPFMANormToPack @@ -12,7 +12,7 @@ from ieee754.fpfma.norm import FPFMANormToPack class FPFMABasePipe(ControlBase): def __init__(self, pspec): super().__init__() - self.sc_denorm = FPFMASpecialCasesDeNorm(pspec) + self.sc_denorm = FPFMASpecialCasesDeNormStage(pspec) self.main = FPFMAMainStage(pspec) self.normpack = FPFMANormToPack(pspec) self._eqs = self.connect([self.sc_denorm, self.main, self.normpack]) diff --git a/src/ieee754/fpfma/special_cases.py b/src/ieee754/fpfma/special_cases.py index 95d30266..826c32a8 100644 --- a/src/ieee754/fpfma/special_cases.py +++ b/src/ieee754/fpfma/special_cases.py @@ -3,14 +3,16 @@ computes `z = (a * c) + b` but only rounds once at the end """ -from nmutil.pipemodbase import PipeModBase +from nmutil.pipemodbase import PipeModBase, PipeModBaseChain from ieee754.fpcommon.basedata import FPBaseData from nmigen.hdl.ast import Signal from nmigen.hdl.dsl import Module from ieee754.fpcommon.getop import FPPipeContext from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge from ieee754.fpfma.util import expanded_exponent_shape, \ - expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape + expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape, \ + EXPANDED_MANTISSA_EXTRA_MSBS, EXPANDED_MANTISSA_EXTRA_LSBS, \ + product_mantissa_shape class FPFMAInputData(FPBaseData): @@ -52,13 +54,13 @@ class FPFMASpecialCasesDeNormOutData: self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf)) """mantissa of a input -- un-normalized and with implicit bit added""" - self.b_mantissa = Signal(multiplicand_mantissa_shape(fpf)) + self.b_mantissa = Signal(expanded_mantissa_shape(fpf)) """mantissa of b input shifted to appropriate location for add and with implicit bit added """ - self.c_mantissa = Signal(expanded_mantissa_shape(fpf)) + self.c_mantissa = Signal(multiplicand_mantissa_shape(fpf)) """mantissa of c input -- un-normalized and with implicit bit added""" self.do_sub = Signal() @@ -123,15 +125,30 @@ class FPFMASpecialCasesDeNorm(PipeModBase): out = self.o a_exponent = Signal(expanded_exponent_shape(fpf)) - m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a)) + m.d.comb += a_exponent.eq(fpf.get_exponent_value(inp.a)) b_exponent_in = Signal(expanded_exponent_shape(fpf)) - m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b)) + m.d.comb += b_exponent_in.eq(fpf.get_exponent_value(inp.b)) c_exponent = Signal(expanded_exponent_shape(fpf)) - m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c)) + m.d.comb += c_exponent.eq(fpf.get_exponent_value(inp.c)) + b_exponent = Signal(expanded_exponent_shape(fpf)) + m.d.comb += b_exponent.eq(b_exponent_in + EXPANDED_MANTISSA_EXTRA_MSBS) prod_exponent = Signal(expanded_exponent_shape(fpf)) - m.d.comb += prod_exponent.eq(a_exponent + c_exponent) + + # number of bits that the product of two normalized signals needs to + # be shifted left to be normalized, e.g. the product of 2 8-bit + # numbers `0x80 * 0x80 == 0x4000` and `0x4000` needs to be shifted + # left by `PROD_STAY_NORM_SHIFT` bits to be normalized again: + # `0x4000 << 1 == 0x8000` + PROD_STAY_NORM_SHIFT = 1 + + extra_prod_exponent = (expanded_mantissa_shape(fpf).width + - product_mantissa_shape(fpf).width + + PROD_STAY_NORM_SHIFT + - EXPANDED_MANTISSA_EXTRA_LSBS) + m.d.comb += prod_exponent.eq(a_exponent + c_exponent + + extra_prod_exponent) prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf)) - m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in) + m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent) b_mantissa_in = Signal(fpf.fraction_width + 1) m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b)) p_sign = Signal() @@ -150,30 +167,37 @@ class FPFMASpecialCasesDeNorm(PipeModBase): ] with m.Else(): m.d.comb += [ - exponent.eq(b_exponent_in), + exponent.eq(b_exponent), b_shift.eq(0), ] - m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width) + m.submodules.rshiftm = rshiftm = MultiShiftRMerge( + out.b_mantissa.width - EXPANDED_MANTISSA_EXTRA_MSBS, + s_max=expanded_exponent_shape(fpf).width - 1) m.d.comb += [ - rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width - - b_mantissa_in.width)), + rshiftm.inp.eq(0), + rshiftm.inp[-b_mantissa_in.width:].eq(b_mantissa_in), rshiftm.diff.eq(b_shift), ] + keep = {"keep": True} + # handle special cases with m.If(fpf.is_nan(inp.a)): m.d.comb += [ + Signal(name="case_nan_a", attrs=keep).eq(True), out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)), out.do_bypass.eq(True), ] with m.Elif(fpf.is_nan(inp.b)): m.d.comb += [ + Signal(name="case_nan_b", attrs=keep).eq(True), out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)), out.do_bypass.eq(True), ] with m.Elif(fpf.is_nan(inp.c)): m.d.comb += [ + Signal(name="case_nan_c", attrs=keep).eq(True), out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)), out.do_bypass.eq(True), ] @@ -181,37 +205,50 @@ class FPFMASpecialCasesDeNorm(PipeModBase): | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))): # infinity * 0 m.d.comb += [ + Signal(name="case_inf_times_zero", attrs=keep).eq(True), out.bypassed_z.eq(fpf.quiet_nan()), out.do_bypass.eq(True), ] with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c)) - & fpf.is_inf(inp.b) & p_sign != b_sign): + & fpf.is_inf(inp.b) & (p_sign != b_sign)): # inf - inf m.d.comb += [ + Signal(name="case_inf_minus_inf", attrs=keep).eq(True), out.bypassed_z.eq(fpf.quiet_nan()), out.do_bypass.eq(True), ] with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)): # inf + x m.d.comb += [ + Signal(name="case_inf_plus_x", attrs=keep).eq(True), out.bypassed_z.eq(fpf.inf(p_sign)), out.do_bypass.eq(True), ] with m.Elif(fpf.is_inf(inp.b)): # x + inf m.d.comb += [ + Signal(name="case_x_plus_inf", attrs=keep).eq(True), out.bypassed_z.eq(fpf.inf(b_sign)), out.do_bypass.eq(True), ] with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c)) - & fpf.is_zero(inp.b) & p_sign == b_sign): + & fpf.is_zero(inp.b) & (p_sign == b_sign)): # zero + zero m.d.comb += [ + Signal(name="case_zero_plus_zero", attrs=keep).eq(True), out.bypassed_z.eq(fpf.zero(p_sign)), out.do_bypass.eq(True), ] - # zero - zero handled by FPFMAMainStage + with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c)) + & ~fpf.is_zero(inp.b)): + # zero + x + m.d.comb += [ + Signal(name="case_zero_plus_x", attrs=keep).eq(True), + out.bypassed_z.eq(inp.b), + out.do_bypass.eq(True), + ] with m.Else(): + # zero - zero handled by FPFMAMainStage m.d.comb += [ out.bypassed_z.eq(0), out.do_bypass.eq(False), @@ -229,3 +266,13 @@ class FPFMASpecialCasesDeNorm(PipeModBase): ] return m + + +class FPFMASpecialCasesDeNormStage(PipeModBaseChain): + def __init__(self, pspec): + super().__init__(pspec) + + def get_chain(self): + """ gets chain of modules + """ + return [FPFMASpecialCasesDeNorm(self.pspec)] diff --git a/src/ieee754/fpfma/test/__init__.py b/src/ieee754/fpfma/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ieee754/fpfma/test/test_fma_formal.py b/src/ieee754/fpfma/test/test_fma_formal.py new file mode 100644 index 00000000..7cea7b8e --- /dev/null +++ b/src/ieee754/fpfma/test/test_fma_formal.py @@ -0,0 +1,559 @@ +import unittest +from nmutil.formaltest import FHDLTestCase +from ieee754.fpfma.pipeline import FPFMABasePipe +from nmigen.hdl.dsl import Module +from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume, Mux +from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \ + SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, SmtBool, \ + SmtRoundingMode, ROUND_TOWARD_POSITIVE, ROUND_TOWARD_NEGATIVE, SmtBitVec +from ieee754.fpcommon.fpbase import FPFormat, FPRoundingMode +from ieee754.pipeline import PipelineSpec +import os + +ENABLE_FMA_F32_FORMAL = os.getenv("ENABLE_FMA_F32_FORMAL") is not None + + +class TestFMAFormal(FHDLTestCase): + @unittest.skip("not finished implementing") # FIXME: remove skip + def tst_fma_formal(self, sort, rm, negate_addend, negate_product): + assert isinstance(sort, SmtSortFloatingPoint) + assert isinstance(rm, FPRoundingMode) + assert isinstance(negate_addend, bool) + assert isinstance(negate_product, bool) + width = sort.width + pspec = PipelineSpec(width, id_width=4, n_ops=3) + pspec.fpformat = FPFormat(e_width=sort.eb, + m_width=sort.mantissa_field_width) + dut = FPFMABasePipe(pspec) + m = Module() + m.submodules.dut = dut + m.d.comb += dut.n.i_ready.eq(True) + m.d.comb += dut.p.i_valid.eq(Initial()) + m.d.comb += dut.p.i_data.rm.eq(Mux(Initial(), rm, 0)) + out = Signal(width) + out_full = Signal(reset=False) + with m.If(dut.n.trigger): + # check we only got output for one cycle + m.d.comb += Assert(~out_full) + m.d.sync += out.eq(dut.n.o_data.z) + m.d.sync += out_full.eq(True) + a = Signal(width) + b = Signal(width) + c = Signal(width) + with m.If(Initial() | True): # FIXME: remove | True + m.d.comb += [ + dut.p.i_data.a.eq(a), + dut.p.i_data.b.eq(b), + dut.p.i_data.c.eq(c), + dut.p.i_data.negate_addend.eq(negate_addend), + dut.p.i_data.negate_product.eq(negate_product), + ] + + def smt_op(a_fp, b_fp, c_fp, rm): + assert isinstance(a_fp, SmtFloatingPoint) + assert isinstance(b_fp, SmtFloatingPoint) + assert isinstance(c_fp, SmtFloatingPoint) + assert isinstance(rm, SmtRoundingMode) + if negate_addend: + b_fp = -b_fp + if negate_product: + a_fp = -a_fp + return a_fp.fma(c_fp, b_fp, rm=rm) + a_fp = SmtFloatingPoint.from_bits(a, sort=sort) + b_fp = SmtFloatingPoint.from_bits(b, sort=sort) + c_fp = SmtFloatingPoint.from_bits(c, sort=sort) + out_fp = SmtFloatingPoint.from_bits(out, sort=sort) + if rm in (FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE, + FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE): + rounded_up = Signal(width) + m.d.comb += rounded_up.eq(AnyConst(width)) + rounded_up_fp = smt_op(a_fp, b_fp, c_fp, rm=ROUND_TOWARD_POSITIVE) + rounded_down_fp = smt_op(a_fp, b_fp, c_fp, + rm=ROUND_TOWARD_NEGATIVE) + m.d.comb += Assume(SmtFloatingPoint.from_bits( + rounded_up, sort=sort).same(rounded_up_fp).as_value()) + use_rounded_up = SmtBool.make(rounded_up[0]) + if rm is FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE: + is_zero = rounded_up_fp.is_zero() & rounded_down_fp.is_zero() + use_rounded_up |= is_zero + expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp) + else: + smt_rm = SmtRoundingMode.make(rm.to_smtlib2()) + expected_fp = smt_op(a_fp, b_fp, c_fp, rm=smt_rm) + expected = Signal(width) + m.d.comb += expected.eq(AnyConst(width)) + quiet_bit = 1 << (sort.mantissa_field_width - 1) + nan_exponent = ((1 << sort.eb) - 1) << sort.mantissa_field_width + with m.If(expected_fp.is_nan().as_value()): + with m.If(a_fp.is_nan().as_value()): + m.d.comb += Assume(expected == (a | quiet_bit)) + with m.Elif(b_fp.is_nan().as_value()): + m.d.comb += Assume(expected == (b | quiet_bit)) + with m.Elif(c_fp.is_nan().as_value()): + m.d.comb += Assume(expected == (c | quiet_bit)) + with m.Else(): + m.d.comb += Assume(expected == (nan_exponent | quiet_bit)) + with m.Else(): + m.d.comb += Assume(SmtFloatingPoint.from_bits(expected, sort=sort) + .same(expected_fp).as_value()) + m.d.comb += a.eq(AnyConst(width)) + m.d.comb += b.eq(AnyConst(width)) + m.d.comb += c.eq(AnyConst(width)) + with m.If(out_full): + m.d.comb += Assert(out_fp.same(expected_fp).as_value()) + m.d.comb += Assert(out == expected) + + def fp_from_int(v): + return SmtFloatingPoint.from_signed_bv( + SmtBitVec.make(v, width=128), + rm=ROUND_TOWARD_POSITIVE, sort=sort) + + # FIXME: remove: + if False: + m.d.comb += Assume(a == 0x05C1) + m.d.comb += Assume(b == 0x877F) + m.d.comb += Assume(c == 0x7437) + with m.If(out_full): + m.d.comb += Assert(out == 0x0000) + m.d.comb += Assert(out == 0x0001) + + self.assertFormal(m, depth=5, solver="bitwuzla") + + # FIXME: check exception flags + + def test_fmadd_f16_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNE, + negate_addend=False, negate_product=False) + + def test_fmsub_f16_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNE, + negate_addend=True, negate_product=False) + + def test_fnmadd_f16_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNE, + negate_addend=True, negate_product=True) + + def test_fnmsub_f16_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNE, + negate_addend=False, negate_product=True) + + def test_fmadd_f16_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTZ, + negate_addend=False, negate_product=False) + + def test_fmsub_f16_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTZ, + negate_addend=True, negate_product=False) + + def test_fnmadd_f16_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTZ, + negate_addend=True, negate_product=True) + + def test_fnmsub_f16_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTZ, + negate_addend=False, negate_product=True) + + def test_fmadd_f16_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTP, + negate_addend=False, negate_product=False) + + def test_fmsub_f16_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTP, + negate_addend=True, negate_product=False) + + def test_fnmadd_f16_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTP, + negate_addend=True, negate_product=True) + + def test_fnmsub_f16_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTP, + negate_addend=False, negate_product=True) + + def test_fmadd_f16_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTN, + negate_addend=False, negate_product=False) + + def test_fmsub_f16_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTN, + negate_addend=True, negate_product=False) + + def test_fnmadd_f16_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTN, + negate_addend=True, negate_product=True) + + def test_fnmsub_f16_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTN, + negate_addend=False, negate_product=True) + + def test_fmadd_f16_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNA, + negate_addend=False, negate_product=False) + + def test_fmsub_f16_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNA, + negate_addend=True, negate_product=False) + + def test_fnmadd_f16_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNA, + negate_addend=True, negate_product=True) + + def test_fnmsub_f16_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNA, + negate_addend=False, negate_product=True) + + def test_fmadd_f16_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTOP, + negate_addend=False, negate_product=False) + + def test_fmsub_f16_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTOP, + negate_addend=True, negate_product=False) + + def test_fnmadd_f16_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTOP, + negate_addend=True, negate_product=True) + + def test_fnmsub_f16_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTOP, + negate_addend=False, negate_product=True) + + def test_fmadd_f16_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTON, + negate_addend=False, negate_product=False) + + def test_fmsub_f16_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTON, + negate_addend=True, negate_product=False) + + def test_fnmadd_f16_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTON, + negate_addend=True, negate_product=True) + + def test_fnmsub_f16_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTON, + negate_addend=False, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmadd_f32_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNE, + negate_addend=False, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmsub_f32_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNE, + negate_addend=True, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmadd_f32_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNE, + negate_addend=True, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmsub_f32_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNE, + negate_addend=False, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmadd_f32_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTZ, + negate_addend=False, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmsub_f32_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTZ, + negate_addend=True, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmadd_f32_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTZ, + negate_addend=True, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmsub_f32_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTZ, + negate_addend=False, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmadd_f32_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTP, + negate_addend=False, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmsub_f32_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTP, + negate_addend=True, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmadd_f32_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTP, + negate_addend=True, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmsub_f32_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTP, + negate_addend=False, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmadd_f32_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTN, + negate_addend=False, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmsub_f32_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTN, + negate_addend=True, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmadd_f32_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTN, + negate_addend=True, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmsub_f32_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTN, + negate_addend=False, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmadd_f32_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNA, + negate_addend=False, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmsub_f32_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNA, + negate_addend=True, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmadd_f32_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNA, + negate_addend=True, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmsub_f32_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNA, + negate_addend=False, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmadd_f32_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTOP, + negate_addend=False, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmsub_f32_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTOP, + negate_addend=True, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmadd_f32_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTOP, + negate_addend=True, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmsub_f32_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTOP, + negate_addend=False, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmadd_f32_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTON, + negate_addend=False, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fmsub_f32_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTON, + negate_addend=True, negate_product=False) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmadd_f32_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTON, + negate_addend=True, negate_product=True) + + @unittest.skipUnless(ENABLE_FMA_F32_FORMAL, + "ENABLE_FMA_F32_FORMAL not in environ") + def test_fnmsub_f32_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTON, + negate_addend=False, negate_product=True) + + @unittest.skip("too slow") + def test_fmadd_f64_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNE, + negate_addend=False, negate_product=False) + + @unittest.skip("too slow") + def test_fmsub_f64_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNE, + negate_addend=True, negate_product=False) + + @unittest.skip("too slow") + def test_fnmadd_f64_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNE, + negate_addend=True, negate_product=True) + + @unittest.skip("too slow") + def test_fnmsub_f64_rne_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNE, + negate_addend=False, negate_product=True) + + @unittest.skip("too slow") + def test_fmadd_f64_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTZ, + negate_addend=False, negate_product=False) + + @unittest.skip("too slow") + def test_fmsub_f64_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTZ, + negate_addend=True, negate_product=False) + + @unittest.skip("too slow") + def test_fnmadd_f64_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTZ, + negate_addend=True, negate_product=True) + + @unittest.skip("too slow") + def test_fnmsub_f64_rtz_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTZ, + negate_addend=False, negate_product=True) + + @unittest.skip("too slow") + def test_fmadd_f64_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTP, + negate_addend=False, negate_product=False) + + @unittest.skip("too slow") + def test_fmsub_f64_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTP, + negate_addend=True, negate_product=False) + + @unittest.skip("too slow") + def test_fnmadd_f64_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTP, + negate_addend=True, negate_product=True) + + @unittest.skip("too slow") + def test_fnmsub_f64_rtp_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTP, + negate_addend=False, negate_product=True) + + @unittest.skip("too slow") + def test_fmadd_f64_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTN, + negate_addend=False, negate_product=False) + + @unittest.skip("too slow") + def test_fmsub_f64_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTN, + negate_addend=True, negate_product=False) + + @unittest.skip("too slow") + def test_fnmadd_f64_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTN, + negate_addend=True, negate_product=True) + + @unittest.skip("too slow") + def test_fnmsub_f64_rtn_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTN, + negate_addend=False, negate_product=True) + + @unittest.skip("too slow") + def test_fmadd_f64_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNA, + negate_addend=False, negate_product=False) + + @unittest.skip("too slow") + def test_fmsub_f64_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNA, + negate_addend=True, negate_product=False) + + @unittest.skip("too slow") + def test_fnmadd_f64_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNA, + negate_addend=True, negate_product=True) + + @unittest.skip("too slow") + def test_fnmsub_f64_rna_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNA, + negate_addend=False, negate_product=True) + + @unittest.skip("too slow") + def test_fmadd_f64_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTOP, + negate_addend=False, negate_product=False) + + @unittest.skip("too slow") + def test_fmsub_f64_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTOP, + negate_addend=True, negate_product=False) + + @unittest.skip("too slow") + def test_fnmadd_f64_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTOP, + negate_addend=True, negate_product=True) + + @unittest.skip("too slow") + def test_fnmsub_f64_rtop_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTOP, + negate_addend=False, negate_product=True) + + @unittest.skip("too slow") + def test_fmadd_f64_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTON, + negate_addend=False, negate_product=False) + + @unittest.skip("too slow") + def test_fmsub_f64_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTON, + negate_addend=True, negate_product=False) + + @unittest.skip("too slow") + def test_fnmadd_f64_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTON, + negate_addend=True, negate_product=True) + + @unittest.skip("too slow") + def test_fnmsub_f64_rton_formal(self): + self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTON, + negate_addend=False, negate_product=True) + + def test_all_rounding_modes_covered(self): + for width in 16, 32, 64: + for rm in FPRoundingMode: + rm_s = rm.name.lower() + name = f"test_fmadd_f{width}_{rm_s}_formal" + assert callable(getattr(self, name)) + name = f"test_fmsub_f{width}_{rm_s}_formal" + assert callable(getattr(self, name)) + name = f"test_fnmadd_f{width}_{rm_s}_formal" + assert callable(getattr(self, name)) + name = f"test_fnmsub_f{width}_{rm_s}_formal" + assert callable(getattr(self, name)) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/ieee754/fpfma/util.py b/src/ieee754/fpfma/util.py index 5372ab8f..518a5d9e 100644 --- a/src/ieee754/fpfma/util.py +++ b/src/ieee754/fpfma/util.py @@ -7,13 +7,48 @@ def expanded_exponent_shape(fpformat): return signed(fpformat.e_width + 3) -EXPANDED_MANTISSA_EXTRA_LSBS = 3 +EXPANDED_MANTISSA_SPACE_BETWEEN_SUM_PROD = 16 # FIXME: change back to 3 +r""" the number of bits of space between the lsb of a large addend and the msb +of the product of two small factors to guarantee that the product ends up +entirely in the sticky bit. + +e.g. let's assume the floating point format has +5 mantissa bits (4 bits in the field + 1 implicit bit): + +if `a` and `b` are `0b11111` and `c` is `0b11111 * 2**-50`, and we are +computing `a * c + b`: + +the computed mantissa would be: + +```text + sticky bit + | + v +0b111110001111000001 + \-b-/ \-product/ +``` + +(note this isn't the mathematically correct +answer, but it rounds to the correct floating-point answer and takes +less hardware) +""" + +# the number of extra LSBs needed by the expanded mantissa to avoid +# having a tiny addend conflict with the lsb of the product. +EXPANDED_MANTISSA_EXTRA_LSBS = 16 # FIXME: change back to 2 + + +# the number of extra MSBs needed by the expanded mantissa to avoid +# overflowing. 2 bits -- 1 bit for carry out of addition, 1 bit for sign. +EXPANDED_MANTISSA_EXTRA_MSBS = 16 # FIXME: change back to 2 def expanded_mantissa_shape(fpformat): assert isinstance(fpformat, FPFormat) - return signed(fpformat.fraction_width * 3 + - 2 + EXPANDED_MANTISSA_EXTRA_LSBS) + return signed((fpformat.fraction_width + 1) * 3 + + EXPANDED_MANTISSA_EXTRA_MSBS + + EXPANDED_MANTISSA_SPACE_BETWEEN_SUM_PROD + + EXPANDED_MANTISSA_EXTRA_LSBS) def multiplicand_mantissa_shape(fpformat): -- 2.30.2