switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / fpadd / test / test_add_formal.py
1 import unittest
2 from nmutil.formaltest import FHDLTestCase
3 from ieee754.fpadd.pipeline import FPADDBasePipe
4 from nmigen.hdl.dsl import Module
5 from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume
6 from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \
7 SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, \
8 ROUND_NEAREST_TIES_TO_EVEN
9 from ieee754.pipeline import PipelineSpec
10
11
class TestFAddFormal(FHDLTestCase):
    """Formal checks of ``FPADDBasePipe`` against the SMT-LIB2 FP reference.

    Each test builds the adder pipeline for one floating-point format and
    drives it with unconstrained constant operands. It then asserts that,
    whenever the pipeline outputs a result, that result equals the SMT-LIB2
    ``fp.add`` reference under round-to-nearest, ties-to-even. The result must
    also be bit-identical to the NaN encoding chosen below.
    """

    def tst_fadd_rne_formal(self, sort):
        """Run a bounded formal check of RNE addition for one format.

        :param sort: an ``SmtSortFloatingPoint`` (e.g. ``SmtSortFloat32()``)
            describing the format; ``sort.width`` is the operand bit width.

        The check uses ``depth=5`` and the ``bitwuzla`` solver.
        """
        assert isinstance(sort, SmtSortFloatingPoint)
        nbits = sort.width
        pipe = FPADDBasePipe(PipelineSpec(nbits, id_width=4))
        top = Module()
        top.submodules.dut = pipe

        # The output side is always ready. The input is only valid in the
        # initial cycle, so exactly one operation is pushed into the pipe.
        top.d.comb += [
            pipe.n.i_ready.eq(True),
            pipe.p.i_valid.eq(Initial()),
        ]

        op_a = pipe.p.i_data.a
        op_b = pipe.p.i_data.b
        result = pipe.n.o_data.z

        # SMT-level views of the raw bit vectors, plus the reference sum.
        a_fp = SmtFloatingPoint.from_bits(op_a, sort=sort)
        b_fp = SmtFloatingPoint.from_bits(op_b, sort=sort)
        result_fp = SmtFloatingPoint.from_bits(result, sort=sort)
        reference_fp = a_fp.add(b_fp, rm=ROUND_NEAREST_TIES_TO_EVEN)

        # `expected` is a free constant. The assumptions below pin it to the
        # exact bit pattern the hardware is required to produce.
        expected = Signal(nbits)
        top.d.comb += expected.eq(AnyConst(nbits))

        mant_bits = sort.mantissa_field_width
        # Top bit of the mantissa field: the IEEE 754 "quiet" bit.
        quiet = 1 << (mant_bits - 1)
        # Exponent field set to all ones (sign bit and mantissa left at 0).
        exp_all_ones = ((1 << sort.eb) - 1) << mant_bits

        with top.If(reference_fp.is_nan().as_value()):
            # A NaN result is the first NaN operand (a before b) with its
            # quiet bit forced on. If neither operand is a NaN, the result
            # is sign 0, exponent all ones, and only the quiet bit set.
            with top.If(a_fp.is_nan().as_value()):
                top.d.comb += Assume(expected == (op_a | quiet))
            with top.Elif(b_fp.is_nan().as_value()):
                top.d.comb += Assume(expected == (op_b | quiet))
            with top.Else():
                top.d.comb += Assume(expected == (exp_all_ones | quiet))
        with top.Else():
            # Non-NaN results: `expected` must be the reference value itself.
            expected_as_fp = SmtFloatingPoint.from_bits(expected, sort=sort)
            top.d.comb += Assume(expected_as_fp.same(reference_fp).as_value())

        # Operands are arbitrary, but constant for the whole trace.
        top.d.comb += [
            op_a.eq(AnyConst(nbits)),
            op_b.eq(AnyConst(nbits)),
        ]

        # When the pipeline emits a result, it must match the reference and
        # also be bit-identical to `expected`, which fixes NaN encodings too.
        with top.If(pipe.n.trigger):
            top.d.sync += [
                Assert(result_fp.same(reference_fp).as_value()),
                Assert(result == expected),
            ]

        self.assertFormal(top, depth=5, solver="bitwuzla")

    # FIXME: check other rounding modes
    # FIXME: check exception flags

    def test_fadd16_rne_formal(self):
        """Binary16 (half precision) RNE addition."""
        self.tst_fadd_rne_formal(SmtSortFloat16())

    def test_fadd32_rne_formal(self):
        """Binary32 (single precision) RNE addition."""
        self.tst_fadd_rne_formal(SmtSortFloat32())

    @unittest.skip("too slow")
    def test_fadd64_rne_formal(self):
        """Binary64 (double precision) RNE addition; skipped as too slow."""
        self.tst_fadd_rne_formal(SmtSortFloat64())
62
63
64
if __name__ == '__main__':
    # Running this file directly executes the formal tests via unittest.
    unittest.main()