2 from nmutil
.formaltest
import FHDLTestCase
3 from ieee754
.fpadd
.pipeline
import FPADDBasePipe
4 from nmigen
.hdl
.dsl
import Module
5 from nmigen
.hdl
.ast
import Initial
, Assert
, AnyConst
, Signal
, Assume
6 from nmigen
.hdl
.smtlib2
import SmtFloatingPoint
, SmtSortFloatingPoint
, \
7 SmtSortFloat16
, SmtSortFloat32
, SmtSortFloat64
, \
8 ROUND_NEAREST_TIES_TO_EVEN
9 from ieee754
.pipeline
import PipelineSpec
class TestFAddFormal(FHDLTestCase):
    """Formal (SMT) checks of ``FPADDBasePipe`` against SMT-LIB
    floating-point addition."""

    def tst_fadd_rne_formal(self, sort):
        """Prove ``FPADDBasePipe`` adds correctly under round-to-nearest-even.

        One pair of arbitrary (``AnyConst``) operands is presented to the DUT
        in the initial cycle.  Whenever the output handshake fires
        (``dut.n.trigger``), the DUT result ``z`` must be SMT-equal to the
        reference sum and must also match a chosen bit pattern ``expected``,
        which pins the NaN encoding the SMT reference leaves open.

        :param sort: an ``SmtSortFloatingPoint`` (e.g. ``SmtSortFloat32()``)
            selecting the format under test.
        """
        assert isinstance(sort, SmtSortFloatingPoint)
        # Interchange-format width: sign bit + exponent field + mantissa field.
        width = 1 + sort.eb + sort.mantissa_field_width
        dut = FPADDBasePipe(PipelineSpec(width, id_width=4))
        m = Module()
        m.submodules.dut = dut
        # The consumer is always ready; the producer offers data only in the
        # initial cycle, so exactly one addition is in flight.
        m.d.comb += dut.n.i_ready.eq(True)
        m.d.comb += dut.p.i_valid.eq(Initial())

        # NOTE(review): these port paths follow the nmutil pipeline
        # i_data/o_data convention -- confirm against FPADDBasePipe.
        a = dut.p.i_data.a
        b = dut.p.i_data.b
        z = dut.n.o_data.z

        rm = ROUND_NEAREST_TIES_TO_EVEN
        a_fp = SmtFloatingPoint.from_bits(a, sort=sort)
        b_fp = SmtFloatingPoint.from_bits(b, sort=sort)
        z_fp = SmtFloatingPoint.from_bits(z, sort=sort)
        expected_fp = a_fp.add(b_fp, rm=rm)

        # Free bit pattern constrained below to the exact output expected
        # from the hardware.
        expected = Signal(width)
        m.d.comb += expected.eq(AnyConst(width))
        # Top bit of the mantissa field: set in a quiet NaN.
        quiet_bit = 1 << (sort.mantissa_field_width - 1)
        # All-ones exponent field (the NaN/infinity exponent).
        nan_exponent = ((1 << sort.eb) - 1) << sort.mantissa_field_width

        with m.If(expected_fp.is_nan().as_value()):
            # SMT-LIB floating point has a single NaN value, so the reference
            # does not determine the NaN bits: propagate the first NaN operand
            # (quieted), otherwise produce the positive canonical quiet NaN.
            with m.If(a_fp.is_nan().as_value()):
                m.d.comb += Assume(expected == (a | quiet_bit))
            with m.Elif(b_fp.is_nan().as_value()):
                m.d.comb += Assume(expected == (b | quiet_bit))
            with m.Else():
                m.d.comb += Assume(expected == (nan_exponent | quiet_bit))
        with m.Else():
            # Non-NaN result: the encoding is fully determined by the
            # reference value.
            m.d.comb += Assume(SmtFloatingPoint.from_bits(expected, sort=sort)
                               .same(expected_fp).as_value())

        m.d.comb += a.eq(AnyConst(width))
        m.d.comb += b.eq(AnyConst(width))
        with m.If(dut.n.trigger):
            m.d.sync += Assert(z_fp.same(expected_fp).as_value())
            m.d.sync += Assert(z == expected)
        self.assertFormal(m, depth=5, solver="bitwuzla")

        # FIXME: check other rounding modes
        # FIXME: check exception flags

    def test_fadd16_rne_formal(self):
        """Half-precision (binary16) round-to-nearest-even addition."""
        self.tst_fadd_rne_formal(SmtSortFloat16())

    def test_fadd32_rne_formal(self):
        """Single-precision (binary32) round-to-nearest-even addition."""
        self.tst_fadd_rne_formal(SmtSortFloat32())

    @unittest.skip("too slow")
    def test_fadd64_rne_formal(self):
        """Double-precision (binary64) round-to-nearest-even addition."""
        self.tst_fadd_rne_formal(SmtSortFloat64())
65 if __name__
== '__main__':