add fsub support to fadd pipeline
[ieee754fpu.git] / src / ieee754 / fpadd / test / test_add_formal.py
index 915c2b94312a8fafc6646cb0ec7ac369055d60cf..95d04d1705675518ab478f59ce5dce9f8d4299e6 100644 (file)
@@ -10,10 +10,11 @@ from ieee754.fpcommon.fpbase import FPRoundingMode
 from ieee754.pipeline import PipelineSpec
 
 
-class TestFAddFormal(FHDLTestCase):
-    def tst_fadd_formal(self, sort, rm):
+class TestFAddFSubFormal(FHDLTestCase):
+    def tst_fadd_fsub_formal(self, sort, rm, is_sub):
         assert isinstance(sort, SmtSortFloatingPoint)
         assert isinstance(rm, FPRoundingMode)
+        assert isinstance(is_sub, bool)
         width = sort.width
         dut = FPADDBasePipe(PipelineSpec(width, id_width=4))
         m = Module()
@@ -32,6 +33,9 @@ class TestFAddFormal(FHDLTestCase):
         b = Signal(width)
         m.d.comb += dut.p.i_data.a.eq(Mux(Initial(), a, 0))
         m.d.comb += dut.p.i_data.b.eq(Mux(Initial(), b, 0))
+        m.d.comb += dut.p.i_data.is_sub.eq(Mux(Initial(), is_sub, 0))
+
+        smt_add_sub = SmtFloatingPoint.sub if is_sub else SmtFloatingPoint.add
         a_fp = SmtFloatingPoint.from_bits(a, sort=sort)
         b_fp = SmtFloatingPoint.from_bits(b, sort=sort)
         out_fp = SmtFloatingPoint.from_bits(out, sort=sort)
@@ -39,8 +43,8 @@ class TestFAddFormal(FHDLTestCase):
                   FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE):
             rounded_up = Signal(width)
             m.d.comb += rounded_up.eq(AnyConst(width))
-            rounded_up_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_POSITIVE)
-            rounded_down_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_NEGATIVE)
+            rounded_up_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_POSITIVE)
+            rounded_down_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_NEGATIVE)
             m.d.comb += Assume(SmtFloatingPoint.from_bits(
                 rounded_up, sort=sort).same(rounded_up_fp).as_value())
             use_rounded_up = SmtBool.make(rounded_up[0])
@@ -50,7 +54,7 @@ class TestFAddFormal(FHDLTestCase):
             expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp)
         else:
             smt_rm = SmtRoundingMode.make(rm.to_smtlib2())
-            expected_fp = a_fp.add(b_fp, rm=smt_rm)
+            expected_fp = smt_add_sub(a_fp, b_fp, rm=smt_rm)
         expected = Signal(width)
         m.d.comb += expected.eq(AnyConst(width))
         quiet_bit = 1 << (sort.mantissa_field_width - 1)
@@ -75,74 +79,144 @@ class TestFAddFormal(FHDLTestCase):
     # FIXME: check exception flags
 
     def test_fadd_f16_rne_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNE)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, False)
 
     def test_fadd_f32_rne_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNE)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rne_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNE)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, False)
 
     def test_fadd_f16_rtz_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTZ)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, False)
 
     def test_fadd_f32_rtz_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTZ)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtz_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTZ)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, False)
 
     def test_fadd_f16_rtp_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTP)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, False)
 
     def test_fadd_f32_rtp_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTP)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtp_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTP)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, False)
 
     def test_fadd_f16_rtn_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTN)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, False)
 
     def test_fadd_f32_rtn_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTN)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtn_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTN)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, False)
 
     def test_fadd_f16_rna_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNA)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, False)
 
     def test_fadd_f32_rna_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNA)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rna_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNA)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, False)
 
     def test_fadd_f16_rtop_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTOP)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, False)
 
     def test_fadd_f32_rtop_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTOP)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtop_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTOP)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, False)
 
     def test_fadd_f16_rton_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTON)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, False)
 
     def test_fadd_f32_rton_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTON)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rton_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTON)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, False)
+
+    def test_fsub_f16_rne_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, True)
+
+    def test_fsub_f32_rne_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rne_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, True)
+
+    def test_fsub_f16_rtz_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, True)
+
+    def test_fsub_f32_rtz_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtz_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, True)
+
+    def test_fsub_f16_rtp_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, True)
+
+    def test_fsub_f32_rtp_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtp_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, True)
+
+    def test_fsub_f16_rtn_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, True)
+
+    def test_fsub_f32_rtn_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtn_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, True)
+
+    def test_fsub_f16_rna_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, True)
+
+    def test_fsub_f32_rna_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rna_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, True)
+
+    def test_fsub_f16_rtop_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, True)
+
+    def test_fsub_f32_rtop_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtop_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, True)
+
+    def test_fsub_f16_rton_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, True)
+
+    def test_fsub_f32_rton_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rton_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, True)
 
     def test_all_rounding_modes_covered(self):
         for width in 16, 32, 64:
@@ -150,6 +224,8 @@ class TestFAddFormal(FHDLTestCase):
                 rm_s = rm.name.lower()
                 name = f"test_fadd_f{width}_{rm_s}_formal"
                 assert callable(getattr(self, name))
+                name = f"test_fsub_f{width}_{rm_s}_formal"
+                assert callable(getattr(self, name))
 
 
 if __name__ == '__main__':