switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / part_shift / formal / proof_shift_scalar.py
index c1821ddcc77eeef450778bc38d6fc26ba1b868fd..5f0666f8938339c8e4ea5c5bfd468d6b2fb6c50c 100644 (file)
@@ -3,7 +3,7 @@
 
 from nmigen import Module, Signal, Elaboratable, Mux, Cat
 from nmigen.asserts import Assert, AnyConst, Assume
-from nmigen.test.utils import FHDLTestCase
+from nmutil.formaltest import FHDLTestCase
 from nmigen.cli import rtlil
 
 from ieee754.part_mul_add.partpoints import PartitionPoints
@@ -41,12 +41,14 @@ class ShifterDriver(Elaboratable):
         shifter = Signal(shifterwidth)
         points = PartitionPoints()
         gates = Signal(mwidth-1)
+        shift_right = Signal()
         step = int(width/mwidth)
         for i in range(mwidth-1):
             points[(i+1)*step] = gates[i]
         print(points)
 
         comb += [data.eq(AnyConst(width)),
+                 shift_right.eq(AnyConst(1)),
                  shifter.eq(AnyConst(shifterwidth)),
                  gates.eq(AnyConst(mwidth-1))]
 
@@ -57,33 +59,84 @@ class ShifterDriver(Elaboratable):
 
         comb += [dut.data.eq(data),
                  dut.shifter.eq(shifter),
+                 dut.shift_right.eq(shift_right),
                  out.eq(dut.output)]
 
         expected = Signal(width)
-        comb += expected.eq(data << shifter)
 
-        with m.Switch(points.as_sig()):
-            with m.Case(0b00):
-                comb += Assert(out[0:24] == (data[0:24] << shifter) & 0xffffff)
-
-            with m.Case(0b01):
-                comb += Assert(out[0:8] == expected[0:8])
-                comb += Assert(out[8:24] == (data[8:24] << shifter) & 0xffff)
-
-            with m.Case(0b10):
-                comb += Assert(out[16:24] == (data[16:24] << shifter) & 0xff)
-                comb += Assert(out[0:16] == (data[0:16] << shifter) & 0xffff)
-
-            with m.Case(0b11):
-                comb += Assert(out[0:8] == expected[0:8])
-                comb += Assert(out[8:16] == (data[8:16] << shifter) & 0xff)
-                comb += Assert(out[16:24] == (data[16:24] << shifter) & 0xff)
+        with m.If(shift_right == 0):
+            with m.Switch(points.as_sig()):
+                with m.Case(0b00):
+                    comb += Assert(
+                        out[0:24] == (data[0:24] << (shifter & 0x1f)) &
+                        0xffffff)
+
+                with m.Case(0b01):
+                    comb += Assert(out[0:8] ==
+                                (data[0:8] << (shifter & 0x7)) & 0xFF)
+                    comb += Assert(out[8:24] ==
+                                (data[8:24] << (shifter & 0xf)) & 0xffff)
+
+                with m.Case(0b10):
+                    comb += Assert(out[16:24] ==
+                                (data[16:24] << (shifter & 0x7)) & 0xff)
+                    comb += Assert(out[0:16] ==
+                                (data[0:16] << (shifter & 0xf)) & 0xffff)
+
+                with m.Case(0b11):
+                    comb += Assert(out[0:8] ==
+                                (data[0:8] << (shifter & 0x7)) & 0xFF)
+                    comb += Assert(out[8:16] ==
+                                (data[8:16] << (shifter & 0x7)) & 0xff)
+                    comb += Assert(out[16:24] ==
+                                (data[16:24] << (shifter & 0x7)) & 0xff)
+        with m.Else():
+            with m.Switch(points.as_sig()):
+                with m.Case(0b00):
+                    comb += Assert(
+                        out[0:24] == (data[0:24] >> (shifter & 0x1f)) &
+                        0xffffff)
+
+                with m.Case(0b01):
+                    comb += Assert(out[0:8] ==
+                                (data[0:8] >> (shifter & 0x7)) & 0xFF)
+                    comb += Assert(out[8:24] ==
+                                (data[8:24] >> (shifter & 0xf)) & 0xffff)
+
+                with m.Case(0b10):
+                    comb += Assert(out[16:24] ==
+                                (data[16:24] >> (shifter & 0x7)) & 0xff)
+                    comb += Assert(out[0:16] ==
+                                (data[0:16] >> (shifter & 0xf)) & 0xffff)
+
+                with m.Case(0b11):
+                    comb += Assert(out[0:8] ==
+                                (data[0:8] >> (shifter & 0x7)) & 0xFF)
+                    comb += Assert(out[8:16] ==
+                                (data[8:16] >> (shifter & 0x7)) & 0xff)
+                    comb += Assert(out[16:24] ==
+                                (data[16:24] >> (shifter & 0x7)) & 0xff)
         return m
 
 class PartitionedScalarShiftTestCase(FHDLTestCase):
     def test_shift(self):
         module = ShifterDriver()
         self.assertFormal(module, mode="bmc", depth=4)
+    def test_ilang(self):
+        width = 24
+        mwidth = 3
+        gates = Signal(mwidth-1)
+        points = PartitionPoints()
+        step = int(width/mwidth)
+        for i in range(mwidth-1):
+            points[(i+1)*step] = gates[i]
+        print(points)
+        dut = PartitionedScalarShift(width, points)
+        vl = rtlil.convert(dut, ports=[gates, dut.data,
+                                       dut.shifter,
+                                       dut.output])
+        with open("scalar_shift.il", "w") as f:
+            f.write(vl)
 
 if __name__ == "__main__":
     unittest.main()