From d502b5b9d5c5ddeb582ab6ad0c998b4f36b3103a Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 27 Jun 2022 22:27:39 -0700 Subject: [PATCH] add correct NaN propagation to the fadd pipeline and formal proof --- src/ieee754/fpadd/specialcases.py | 14 ++++++++---- src/ieee754/fpadd/test/test_add_formal.py | 26 +++++++++++++++-------- src/ieee754/fpcommon/fpbase.py | 6 ++++++ 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/ieee754/fpadd/specialcases.py b/src/ieee754/fpadd/specialcases.py index 656ace3c..3b6c6921 100644 --- a/src/ieee754/fpadd/specialcases.py +++ b/src/ieee754/fpadd/specialcases.py @@ -80,11 +80,15 @@ class FPAddSpecialCasesMod(PipeModBase): # prepare inf/zero/nans z_zero = FPNumBaseRecord(width, False, name="z_zero") - z_nan = FPNumBaseRecord(width, False, name="z_nan") + z_default_nan = FPNumBaseRecord(width, False, name="z_default_nan") + z_quieted_a = FPNumBaseRecord(width, False, name="z_quieted_a") + z_quieted_b = FPNumBaseRecord(width, False, name="z_quieted_b") z_infa = FPNumBaseRecord(width, False, name="z_infa") z_infb = FPNumBaseRecord(width, False, name="z_infb") comb += z_zero.zero(0) - comb += z_nan.nan(0) + comb += z_default_nan.nan(0) + comb += z_quieted_a.quieted_nan(a1) + comb += z_quieted_b.quieted_nan(b1) comb += z_infa.inf(a1.s) comb += z_infb.inf(b1.s) @@ -93,6 +97,8 @@ class FPAddSpecialCasesMod(PipeModBase): # this is the logic-decision-making for special-cases: # if a is NaN or b is NaN return NaN + # if a is NaN return quieted_nan(a) + # else return quieted_nan(b) # elif a is inf return inf (or NaN) # if a is inf and signs don't match return NaN # else return inf(a) @@ -112,8 +118,8 @@ class FPAddSpecialCasesMod(PipeModBase): oz = Mux(t_a1zero, b1.v, oz) oz = Mux(t_abz, Cat(self.i.b[:-1], absa), oz) oz = Mux(t_b1inf, z_infb.v, oz) - oz = Mux(t_a1inf, Mux(bexp128s, z_nan.v, z_infa.v), oz) - oz = Mux(t_abnan, z_nan.v, oz) + oz = Mux(t_a1inf, Mux(bexp128s, z_default_nan.v, z_infa.v), oz) + oz = Mux(t_abnan, Mux(a1.is_nan, z_quieted_a.v, z_quieted_b.v), oz) comb += self.o.oz.eq(oz) diff --git a/src/ieee754/fpadd/test/test_add_formal.py b/src/ieee754/fpadd/test/test_add_formal.py index 013aca57..95e842d0 100644 --- a/src/ieee754/fpadd/test/test_add_formal.py +++ b/src/ieee754/fpadd/test/test_add_formal.py @@ -2,7 +2,7 @@ import unittest from nmutil.formaltest import FHDLTestCase from ieee754.fpadd.pipeline import FPADDBasePipe from nmigen.hdl.dsl import Module -from nmigen.hdl.ast import AnySeq, Initial, Assert, AnyConst, Signal, Assume +from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \ SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, \ ROUND_NEAREST_TIES_TO_EVEN @@ -27,17 +27,25 @@ class TestFAddFormal(FHDLTestCase): z_fp = SmtFloatingPoint.from_bits(z, sort=sort) expected_fp = a_fp.add(b_fp, rm=rm) expected = Signal(width) - m.d.comb += expected.eq(AnySeq(width)) - # Important Note: expected and z won't necessarily match bit-exactly - # if it's a NaN, all this checks for is z is also any NaN - m.d.comb += Assume((SmtFloatingPoint.from_bits(expected, sort=sort) - == expected_fp).as_value()) - # FIXME: check that it produces the correct NaNs + m.d.comb += expected.eq(AnyConst(width)) + quiet_bit = 1 << (sort.mantissa_field_width - 1) + nan_exponent = ((1 << sort.eb) - 1) << sort.mantissa_field_width + with m.If(expected_fp.is_nan().as_value()): + with m.If(a_fp.is_nan().as_value()): + m.d.comb += Assume(expected == (a | quiet_bit)) + with m.Elif(b_fp.is_nan().as_value()): + m.d.comb += Assume(expected == (b | quiet_bit)) + with m.Else(): + m.d.comb += Assume(expected == (nan_exponent | quiet_bit)) + with m.Else(): + m.d.comb += Assume(SmtFloatingPoint.from_bits(expected, sort=sort) + .same(expected_fp).as_value()) m.d.comb += a.eq(AnyConst(width)) m.d.comb += b.eq(AnyConst(width)) with m.If(dut.n.trigger): - m.d.sync += Assert((z_fp == expected_fp).as_value()) - self.assertFormal(m, depth=5, solver="z3") + m.d.sync += Assert(z_fp.same(expected_fp).as_value()) + m.d.sync += Assert(z == expected) + self.assertFormal(m, depth=5, solver="bitwuzla") # FIXME: check other rounding modes # FIXME: check exception flags diff --git a/src/ieee754/fpcommon/fpbase.py b/src/ieee754/fpcommon/fpbase.py index 30178633..6e40b021 100644 --- a/src/ieee754/fpcommon/fpbase.py +++ b/src/ieee754/fpcommon/fpbase.py @@ -425,6 +425,12 @@ class FPNumBaseRecord: def nan(self, s): return self.create(*self._nan(s)) + def quieted_nan(self, other): + assert isinstance(other, FPNumBaseRecord) + assert self.width == other.width + return self.create(other.s, self.fp.P128, + other.v[0:self.e_start] | (1 << (self.e_start - 1))) + def inf(self, s): return self.create(*self._inf(s)) -- 2.30.2