re-add masking of the shift amount based on partition length
[ieee754fpu.git] / src / ieee754 / part_shift / formal / proof_shift_dynamic.py
index d4e7a525282b7ab7a4e753ed7939202074b5dcd2..a836771c2262ebd0bf24f2d720561a5367257708 100644 (file)
@@ -2,7 +2,7 @@
 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
 
 from nmigen import Module, Signal, Elaboratable, Mux, Cat
-from nmigen.asserts import Assert, AnyConst, Assume
+from nmigen.asserts import Assert, AnyConst
 from nmigen.test.utils import FHDLTestCase
 from nmigen.cli import rtlil
 
@@ -63,71 +63,55 @@ class ShifterDriver(Elaboratable):
 
         with m.Switch(points.as_sig()):
             with m.Case(0b000):
-                comb += Assume(b <= 32)
-                comb += Assert(out == (a<<b[0:6]) & 0xffffffff)
+                comb += Assert(out == (a<<b[0:5]) & 0xffffffff)
             with m.Case(0b001):
-                comb += Assume(b_intervals[0] <= 8)
                 comb += Assert(out_intervals[0] ==
-                               (a_intervals[0] << b_intervals[0]) & 0xff)
-                comb += Assume(b_intervals[1] <= 24)
+                               (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
                 comb += Assert(Cat(out_intervals[1:4]) ==
                                (Cat(a_intervals[1:4])
-                                << b_intervals[1]) & 0xffffff)
+                                << b_intervals[1][0:5]) & 0xffffff)
             with m.Case(0b010):
-                comb += Assume(b_intervals[0] <= 16)
                 comb += Assert(Cat(out_intervals[0:2]) ==
                                (Cat(a_intervals[0:2])
-                                << b_intervals[0]) & 0xffff)
-                comb += Assume(b_intervals[2] <= 16)
+                                << (b_intervals[0] & 0xf)) & 0xffff)
                 comb += Assert(Cat(out_intervals[2:4]) ==
                                (Cat(a_intervals[2:4])
-                                << b_intervals[2]) & 0xffff)
+                                << (b_intervals[2] & 0xf)) & 0xffff)
             with m.Case(0b011):
-                comb += Assume(b_intervals[0] <= 8)
                 comb += Assert(out_intervals[0] ==
-                               (a_intervals[0] << b_intervals[0]) & 0xff)
-                comb += Assume(b_intervals[1] <= 8)
+                               (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
                 comb += Assert(out_intervals[1] ==
-                               (a_intervals[1] << b_intervals[1]) & 0xff)
-                comb += Assume(b_intervals[2] <= 16)
+                               (a_intervals[1] << b_intervals[1][0:3]) & 0xff)
                 comb += Assert(Cat(out_intervals[2:4]) ==
                                (Cat(a_intervals[2:4])
-                                << b_intervals[2]) & 0xffff)
+                                << b_intervals[2][0:4]) & 0xffff)
             with m.Case(0b100):
-                comb += Assume(b_intervals[0] <= 24)
                 comb += Assert(Cat(out_intervals[0:3]) ==
                                (Cat(a_intervals[0:3])
-                                << b_intervals[0]) & 0xffffff)
-                comb += Assume(b_intervals[3] <= 8)
+                                << b_intervals[0][0:5]) & 0xffffff)
                 comb += Assert(out_intervals[3] ==
-                               (a_intervals[3] << b_intervals[3]) & 0xff)
+                               (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
             with m.Case(0b101):
-                comb += Assume(b_intervals[0] <= 8)
                 comb += Assert(out_intervals[0] ==
-                               (a_intervals[0] << b_intervals[0]) & 0xff)
-                comb += Assume(b_intervals[1] <= 16)
+                               (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
                 comb += Assert(Cat(out_intervals[1:3]) ==
                                (Cat(a_intervals[1:3])
-                                << b_intervals[1]) & 0xffff)
-                comb += Assume(b_intervals[3] <= 8)
+                                << b_intervals[1][0:4]) & 0xffff)
                 comb += Assert(out_intervals[3] ==
-                               (a_intervals[3] << b_intervals[3]) & 0xff)
+                               (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
             with m.Case(0b110):
-                comb += Assume(b_intervals[0] <= 16)
                 comb += Assert(Cat(out_intervals[0:2]) ==
                                (Cat(a_intervals[0:2])
-                                << b_intervals[0]) & 0xffff)
-                comb += Assume(b_intervals[2] <= 8)
+                                << b_intervals[0][0:4]) & 0xffff)
                 comb += Assert(out_intervals[2] ==
-                               (a_intervals[2] << b_intervals[2]) & 0xff)
-                comb += Assume(b_intervals[3] <= 8)
+                               (a_intervals[2] << b_intervals[2][0:3]) & 0xff)
                 comb += Assert(out_intervals[3] ==
-                               (a_intervals[3] << b_intervals[3]) & 0xff)
+                               (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
             with m.Case(0b111):
                 for i, o in enumerate(out_intervals):
-                    comb += Assume(b_intervals[i] <= 8)
                     comb += Assert(o ==
-                                   (a_intervals[i] << b_intervals[i]) & 0xff)
+                                   (a_intervals[i] << b_intervals[i][0:3])
+                                   & 0xff)
 
         return m
 
@@ -135,6 +119,7 @@ class PartitionedDynamicShiftTestCase(FHDLTestCase):
     def test_shift(self):
         module = ShifterDriver()
         self.assertFormal(module, mode="bmc", depth=4)
+
     def test_ilang(self):
         width = 64
         mwidth = 8
@@ -152,4 +137,3 @@ class PartitionedDynamicShiftTestCase(FHDLTestCase):
 
 if __name__ == "__main__":
     unittest.main()
-