From 5b37dceef588adc9f544c76da2ec2441f2aade7f Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 30 Jun 2022 22:07:27 -0700 Subject: [PATCH] add fsub support to fadd pipeline --- src/ieee754/fpadd/specialcases.py | 37 +++++-- src/ieee754/fpadd/test/test_add_formal.py | 128 +++++++++++++++++----- src/ieee754/fpcommon/fpbase.py | 32 +++--- 3 files changed, 148 insertions(+), 49 deletions(-) diff --git a/src/ieee754/fpadd/specialcases.py b/src/ieee754/fpadd/specialcases.py index 31a49068..8ae66e7e 100644 --- a/src/ieee754/fpadd/specialcases.py +++ b/src/ieee754/fpadd/specialcases.py @@ -5,13 +5,31 @@ from nmigen import Module, Signal, Cat, Mux from nmutil.pipemodbase import PipeModBase, PipeModBaseChain -from ieee754.fpcommon.fpbase import FPNumDecode, FPRoundingMode +from ieee754.fpcommon.fpbase import FPFormat, FPNumDecode, FPRoundingMode from ieee754.fpcommon.fpbase import FPNumBaseRecord from ieee754.fpcommon.basedata import FPBaseData from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod) +class FPAddInputData(FPBaseData): + def __init__(self, pspec): + super().__init__(pspec) + self.is_sub = Signal(reset=False) + + def eq(self, i): + ret = super().eq(i) + ret.append(self.is_sub.eq(i.is_sub)) + return ret + + def __iter__(self): + yield from super().__iter__() + yield self.is_sub + + def ports(self): + return list(self) + + class FPAddSpecialCasesMod(PipeModBase): """ special cases: NaNs, infs, zeros, denormalised NOTE: some of these are unique to add. 
see "Special Operations" @@ -22,7 +40,7 @@ class FPAddSpecialCasesMod(PipeModBase): super().__init__(pspec, "specialcases") def ispec(self): - return FPBaseData(self.pspec) + return FPAddInputData(self.pspec) def ospec(self): return FPSCData(self.pspec, True) @@ -37,11 +55,16 @@ class FPAddSpecialCasesMod(PipeModBase): b1 = FPNumBaseRecord(width) m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1) m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1) - comb += [a1.v.eq(self.i.a), - b1.v.eq(self.i.b), - self.o.a.eq(a1), - self.o.b.eq(b1) - ] + flip_b_sign = Signal() + b_is_nan = Signal() + comb += [ + b_is_nan.eq(FPFormat.standard(width).is_nan(self.i.b)), + flip_b_sign.eq(self.i.is_sub & ~b_is_nan), + a1.v.eq(self.i.a), + b1.v.eq(self.i.b ^ (flip_b_sign << (width - 1))), + self.o.a.eq(a1), + self.o.b.eq(b1) + ] zero_sign_array = FPRoundingMode.make_array(FPRoundingMode.zero_sign) diff --git a/src/ieee754/fpadd/test/test_add_formal.py b/src/ieee754/fpadd/test/test_add_formal.py index 915c2b94..95d04d17 100644 --- a/src/ieee754/fpadd/test/test_add_formal.py +++ b/src/ieee754/fpadd/test/test_add_formal.py @@ -10,10 +10,11 @@ from ieee754.fpcommon.fpbase import FPRoundingMode from ieee754.pipeline import PipelineSpec -class TestFAddFormal(FHDLTestCase): - def tst_fadd_formal(self, sort, rm): +class TestFAddFSubFormal(FHDLTestCase): + def tst_fadd_fsub_formal(self, sort, rm, is_sub): assert isinstance(sort, SmtSortFloatingPoint) assert isinstance(rm, FPRoundingMode) + assert isinstance(is_sub, bool) width = sort.width dut = FPADDBasePipe(PipelineSpec(width, id_width=4)) m = Module() @@ -32,6 +33,9 @@ class TestFAddFormal(FHDLTestCase): b = Signal(width) m.d.comb += dut.p.i_data.a.eq(Mux(Initial(), a, 0)) m.d.comb += dut.p.i_data.b.eq(Mux(Initial(), b, 0)) + m.d.comb += dut.p.i_data.is_sub.eq(Mux(Initial(), is_sub, 0)) + + smt_add_sub = SmtFloatingPoint.sub if is_sub else SmtFloatingPoint.add a_fp = SmtFloatingPoint.from_bits(a, sort=sort) b_fp = 
SmtFloatingPoint.from_bits(b, sort=sort) out_fp = SmtFloatingPoint.from_bits(out, sort=sort) @@ -39,8 +43,8 @@ class TestFAddFormal(FHDLTestCase): FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE): rounded_up = Signal(width) m.d.comb += rounded_up.eq(AnyConst(width)) - rounded_up_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_POSITIVE) - rounded_down_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_NEGATIVE) + rounded_up_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_POSITIVE) + rounded_down_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_NEGATIVE) m.d.comb += Assume(SmtFloatingPoint.from_bits( rounded_up, sort=sort).same(rounded_up_fp).as_value()) use_rounded_up = SmtBool.make(rounded_up[0]) @@ -50,7 +54,7 @@ class TestFAddFormal(FHDLTestCase): expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp) else: smt_rm = SmtRoundingMode.make(rm.to_smtlib2()) - expected_fp = a_fp.add(b_fp, rm=smt_rm) + expected_fp = smt_add_sub(a_fp, b_fp, rm=smt_rm) expected = Signal(width) m.d.comb += expected.eq(AnyConst(width)) quiet_bit = 1 << (sort.mantissa_field_width - 1) @@ -75,74 +79,144 @@ class TestFAddFormal(FHDLTestCase): # FIXME: check exception flags def test_fadd_f16_rne_formal(self): - self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNE) + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, False) def test_fadd_f32_rne_formal(self): - self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNE) + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, False) @unittest.skip("too slow") def test_fadd_f64_rne_formal(self): - self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNE) + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, False) def test_fadd_f16_rtz_formal(self): - self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTZ) + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, False) def test_fadd_f32_rtz_formal(self): - self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTZ) + self.tst_fadd_fsub_formal(SmtSortFloat32(), 
FPRoundingMode.RTZ, False) @unittest.skip("too slow") def test_fadd_f64_rtz_formal(self): - self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTZ) + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, False) def test_fadd_f16_rtp_formal(self): - self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTP) + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, False) def test_fadd_f32_rtp_formal(self): - self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTP) + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, False) @unittest.skip("too slow") def test_fadd_f64_rtp_formal(self): - self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTP) + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, False) def test_fadd_f16_rtn_formal(self): - self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTN) + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, False) def test_fadd_f32_rtn_formal(self): - self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTN) + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, False) @unittest.skip("too slow") def test_fadd_f64_rtn_formal(self): - self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTN) + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, False) def test_fadd_f16_rna_formal(self): - self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNA) + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, False) def test_fadd_f32_rna_formal(self): - self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNA) + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, False) @unittest.skip("too slow") def test_fadd_f64_rna_formal(self): - self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNA) + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, False) def test_fadd_f16_rtop_formal(self): - self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTOP) + self.tst_fadd_fsub_formal(SmtSortFloat16(), 
FPRoundingMode.RTOP, False) def test_fadd_f32_rtop_formal(self): - self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTOP) + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, False) @unittest.skip("too slow") def test_fadd_f64_rtop_formal(self): - self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTOP) + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, False) def test_fadd_f16_rton_formal(self): - self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTON) + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, False) def test_fadd_f32_rton_formal(self): - self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTON) + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, False) @unittest.skip("too slow") def test_fadd_f64_rton_formal(self): - self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTON) + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, False) + + def test_fsub_f16_rne_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, True) + + def test_fsub_f32_rne_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, True) + + @unittest.skip("too slow") + def test_fsub_f64_rne_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, True) + + def test_fsub_f16_rtz_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, True) + + def test_fsub_f32_rtz_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, True) + + @unittest.skip("too slow") + def test_fsub_f64_rtz_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, True) + + def test_fsub_f16_rtp_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, True) + + def test_fsub_f32_rtp_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, True) + + @unittest.skip("too slow") + def test_fsub_f64_rtp_formal(self): + 
self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, True) + + def test_fsub_f16_rtn_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, True) + + def test_fsub_f32_rtn_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, True) + + @unittest.skip("too slow") + def test_fsub_f64_rtn_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, True) + + def test_fsub_f16_rna_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, True) + + def test_fsub_f32_rna_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, True) + + @unittest.skip("too slow") + def test_fsub_f64_rna_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, True) + + def test_fsub_f16_rtop_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, True) + + def test_fsub_f32_rtop_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, True) + + @unittest.skip("too slow") + def test_fsub_f64_rtop_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, True) + + def test_fsub_f16_rton_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, True) + + def test_fsub_f32_rton_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, True) + + @unittest.skip("too slow") + def test_fsub_f64_rton_formal(self): + self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, True) def test_all_rounding_modes_covered(self): for width in 16, 32, 64: @@ -150,6 +224,8 @@ class TestFAddFormal(FHDLTestCase): rm_s = rm.name.lower() name = f"test_fadd_f{width}_{rm_s}_formal" assert callable(getattr(self, name)) + name = f"test_fsub_f{width}_{rm_s}_formal" + assert callable(getattr(self, name)) if __name__ == '__main__': diff --git a/src/ieee754/fpcommon/fpbase.py b/src/ieee754/fpcommon/fpbase.py index fee88bef..fd707bc1 100644 --- 
a/src/ieee754/fpcommon/fpbase.py
+++ b/src/ieee754/fpcommon/fpbase.py
@@ -322,42 +322,42 @@ class FPFormat:
     def is_zero(self, x):
         """ returns true if x is +/- zero
         """
-        return (self.get_exponent(x) == self.e_sub and
-                self.get_mantissa_field(x) == 0)
+        return (self.get_exponent(x) == self.e_sub) & \
+            (self.get_mantissa_field(x) == 0)
 
     def is_subnormal(self, x):
         """ returns true if x is subnormal (exp at minimum)
         """
-        return (self.get_exponent(x) == self.e_sub and
-                self.get_mantissa_field(x) != 0)
+        return (self.get_exponent(x) == self.e_sub) & \
+            (self.get_mantissa_field(x) != 0)
 
     def is_inf(self, x):
         """ returns true if x is infinite
         """
-        return (self.get_exponent(x) == self.e_max and
-                self.get_mantissa_field(x) == 0)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) == 0)
 
     def is_nan(self, x):
         """ returns true if x is a nan (quiet or signalling)
         """
-        return (self.get_exponent(x) == self.e_max and
-                self.get_mantissa_field(x) != 0)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) != 0)
 
     def is_quiet_nan(self, x):
         """ returns true if x is a quiet nan
         """
-        highbit = 1<<(self.m_width-1)
-        return (self.get_exponent(x) == self.e_max and
-                self.get_mantissa_field(x) != 0 and
-                self.get_mantissa_field(x) & highbit != 0)
+        highbit = 1 << (self.m_width - 1)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) != 0) & \
+            (self.get_mantissa_field(x) & highbit != 0)
 
     def is_nan_signaling(self, x):
         """ returns true if x is a signalling nan
         """
-        highbit = 1<<(self.m_width-1)
-        return ((self.get_exponent(x) == self.e_max) and
-                (self.get_mantissa_field(x) != 0) and
-                (self.get_mantissa_field(x) & highbit) == 0)
+        highbit = 1 << (self.m_width - 1)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) != 0) & \
+            ((self.get_mantissa_field(x) & highbit) == 0)
 
     @property
     def width(self):
-- 
2.30.2