use temporary python vars rather than copy signals (shorter code)
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 10 May 2020 05:42:43 +0000 (06:42 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 10 May 2020 05:42:43 +0000 (06:42 +0100)
src/soc/alu/formal/proof_main_stage.py

index 19100079c9ae5ccacf4352bc1b7bd453e5b55a9e..df73815b745f246cdc7ebec0cf4be537580b296c 100644 (file)
@@ -36,39 +36,38 @@ class Driver(Elaboratable):
         pspec = ALUPipeSpec(id_wid=2, op_wid=recwidth)
         m.submodules.dut = dut = ALUMainStage(pspec)
 
-        a = Signal(64)
-        b = Signal(64)
-        carry_in = Signal()
-        so_in = Signal()
-        comb += [dut.i.a.eq(a),
-                 dut.i.b.eq(b),
-                 dut.i.carry_in.eq(carry_in),
-                 dut.i.so.eq(so_in),
-                 a.eq(AnyConst(64)),
+        # convenience variables
+        a = dut.i.a
+        b = dut.i.b
+        carry_in = dut.i.carry_in
+        so_in = dut.i.so
+        carry_out = dut.o.carry_out
+        o = dut.o.o
+
+        # setup random inputs
+        comb += [a.eq(AnyConst(64)),
                  b.eq(AnyConst(64)),
                  carry_in.eq(AnyConst(1)),
                  so_in.eq(AnyConst(1))]
-                      
 
         comb += dut.i.ctx.op.eq(rec)
 
-
         # Assert that op gets copied from the input to output
-        for p in rec.ports():
-            name = p.name
-            rec_sig = p
+        for rec_sig in rec.ports():
+            name = rec_sig.name
             dut_sig = getattr(dut.o.ctx.op, name)
             comb += Assert(dut_sig == rec_sig)
 
+        # signed and signed/32 versions of input a
         a_signed = Signal(signed(64))
-        comb += a_signed.eq(a)
         a_signed_32 = Signal(signed(32))
+        comb += a_signed.eq(a)
         comb += a_signed_32.eq(a[0:32])
 
+        # main assertion of arithmetic operations
         with m.Switch(rec.insn_type):
             with m.Case(InternalOp.OP_ADD):
-                comb += Assert(Cat(dut.o.o, dut.o.carry_out) ==
-                               (a + b + carry_in))
+                comb += Assert(Cat(o, carry_out) == (a + b + carry_in))
             with m.Case(InternalOp.OP_AND):
                 comb += Assert(dut.o.o == a & b)
             with m.Case(InternalOp.OP_OR):
@@ -77,32 +76,28 @@ class Driver(Elaboratable):
                 comb += Assert(dut.o.o == a ^ b)
             with m.Case(InternalOp.OP_SHL):
                 with m.If(rec.is_32bit):
-                    comb += Assert(dut.o.o[0:32] == ((a << b[0:6]) &
-                                                     0xffffffff))
-                    comb += Assert(dut.o.o[32:64] == 0)
+                    comb += Assert(o[0:32] == ((a << b[0:6]) & 0xffffffff))
+                    comb += Assert(o[32:64] == 0)
                 with m.Else():
-                    comb += Assert(dut.o.o == ((a << b[0:7]) &
-                                               ((1 << 64)-1)))
+                    comb += Assert(o == ((a << b[0:7]) & ((1 << 64)-1)))
             with m.Case(InternalOp.OP_SHR):
                 with m.If(~rec.is_signed):
                     with m.If(rec.is_32bit):
-                        comb += Assert(dut.o.o[0:32] ==
-                                       (a[0:32] >> b[0:6]))
-                        comb += Assert(dut.o.o[32:64] == 0)
+                        comb += Assert(o[0:32] == (a[0:32] >> b[0:6]))
+                        comb += Assert(o[32:64] == 0)
                     with m.Else():
-                        comb += Assert(dut.o.o == (a >> b[0:7]))
+                        comb += Assert(o == (a >> b[0:7]))
                 with m.Else():
                     with m.If(rec.is_32bit):
-                        comb += Assert(dut.o.o[0:32] ==
-                                       (a_signed_32 >> b[0:6]))
-                        comb += Assert(dut.o.o[32:64] == Repl(a[31], 32))
+                        comb += Assert(o[0:32] == (a_signed_32 >> b[0:6]))
+                        comb += Assert(o[32:64] == Repl(a[31], 32))
                     with m.Else():
-                        comb += Assert(dut.o.o == (a_signed >> b[0:7]))
-
+                        comb += Assert(o == (a_signed >> b[0:7]))
 
         return m
 
-class GTCombinerTestCase(FHDLTestCase):
+
+class ALUTestCase(FHDLTestCase):
     def test_formal(self):
         module = Driver()
         self.assertFormal(module, mode="bmc", depth=4)