Add formal proof for branch unit, fix bug with bcreg
authorMichael Nolan <mtnolan2640@gmail.com>
Fri, 22 May 2020 18:20:13 +0000 (14:20 -0400)
committerMichael Nolan <mtnolan2640@gmail.com>
Fri, 22 May 2020 18:21:49 +0000 (14:21 -0400)
src/soc/fu/branch/formal/proof_main_stage.py
src/soc/fu/branch/main_stage.py

index 804643df5a01160c6996bae3fe7d8c89133ecf16..0878efedaa75d03d846a78142dfbd7e77b854594 100644 (file)
@@ -2,12 +2,13 @@
 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
 
 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
 
 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
-                    signed)
+                    signed, Array)
 from nmigen.asserts import Assert, AnyConst, Assume, Cover
 from nmigen.test.utils import FHDLTestCase
 from nmigen.asserts import Assert, AnyConst, Assume, Cover
 from nmigen.test.utils import FHDLTestCase
+from nmutil.extend import exts
 from nmigen.cli import rtlil
 
 from nmigen.cli import rtlil
 
-from soc.fu.logical.main_stage import LogicalMainStage
+from soc.fu.branch.main_stage import BranchMainStage
 from soc.fu.alu.pipe_data import ALUPipeSpec
 from soc.fu.alu.alu_input_record import CompALUOpSubset
 from soc.decoder.power_enums import InternalOp
 from soc.fu.alu.pipe_data import ALUPipeSpec
 from soc.fu.alu.alu_input_record import CompALUOpSubset
 from soc.decoder.power_enums import InternalOp
@@ -33,22 +34,8 @@ class Driver(Elaboratable):
             recwidth += width
             comb += p.eq(AnyConst(width))
 
             recwidth += width
             comb += p.eq(AnyConst(width))
 
-        pspec = ALUPipeSpec(id_wid=2, op_wid=recwidth)
-        m.submodules.dut = dut = LogicalMainStage(pspec)
-
-        # convenience variables
-        a = dut.i.a
-        b = dut.i.b
-        carry_in = dut.i.carry_in
-        so_in = dut.i.so
-        carry_out = dut.o.carry_out
-        o = dut.o.o
-
-        # setup random inputs
-        comb += [a.eq(AnyConst(64)),
-                 b.eq(AnyConst(64)),
-                 carry_in.eq(AnyConst(1)),
-                 so_in.eq(AnyConst(1))]
+        pspec = ALUPipeSpec(id_wid=2)
+        m.submodules.dut = dut = BranchMainStage(pspec)
 
         comb += dut.i.ctx.op.eq(rec)
 
 
         comb += dut.i.ctx.op.eq(rec)
 
@@ -58,20 +45,113 @@ class Driver(Elaboratable):
             dut_sig = getattr(dut.o.ctx.op, name)
             comb += Assert(dut_sig == rec_sig)
 
             dut_sig = getattr(dut.o.ctx.op, name)
             comb += Assert(dut_sig == rec_sig)
 
-        # signed and signed/32 versions of input a
-        a_signed = Signal(signed(64))
-        a_signed_32 = Signal(signed(32))
-        comb += a_signed.eq(a)
-        comb += a_signed_32.eq(a[0:32])
+        # Full width CR register. Will have bitfield extracted for
+        # feeding to branch unit
+        cr = Signal(32)
+        comb += cr.eq(AnyConst(32))
+        cr_arr = Array([cr[(7-i)*4:(7-i)*4+4] for i in range(8)])
+        cr_bit_arr = Array([cr[31-i] for i in range(32)])
+
+        spr1 = dut.i.spr1
+        ctr = dut.i.spr2
+        cr_in = dut.i.cr
+        cia = dut.i.cia
+
+        comb += [spr1.eq(AnyConst(64)),
+                 ctr.eq(AnyConst(64)),
+                 cia.eq(AnyConst(64))]
+
+        i_fields = dut.fields.FormI
+        b_fields = dut.fields.FormB
+        AA = i_fields.AA[0:-1]
+        LK = i_fields.LK[0:-1]
+
+        # Handle CR bit selection
+        BI = b_fields.BI[0:-1]
+        bi = Signal(3, reset_less=True)
+        comb += bi.eq(BI[2:5])
+        comb += dut.i.cr.eq(cr_arr[bi])
+
+        # Handle branch out
+        BO = b_fields.BO[0:-1]
+        bo = Signal(BO.shape())
+        comb += bo.eq(BO)
+        cond_ok = Signal()
+
+        # Check CR according to BO
+        comb += cond_ok.eq(bo[4] | (cr_bit_arr[BI] == bo[3]))
+
+        # CTR decrement
+        ctr_next = Signal.like(ctr)
+        with m.If(~BO[2]):
+            comb += ctr_next.eq(ctr - 1)
+        with m.Else():
+            comb += ctr_next.eq(ctr)
+
+        # CTR combpare with 0
+        ctr_ok = Signal()
+        comb += ctr_ok.eq(BO[2] | ((ctr != 0) ^ BO[1]))
+        
+        # Sorry, not bothering with 32 bit right now
+        comb += Assume(~rec.is_32bit)
 
 
-        # main assertion of arithmetic operations
         with m.Switch(rec.insn_type):
         with m.Switch(rec.insn_type):
-            with m.Case(InternalOp.OP_AND):
-                comb += Assert(dut.o.o == a & b)
-            with m.Case(InternalOp.OP_OR):
-                comb += Assert(dut.o.o == a | b)
-            with m.Case(InternalOp.OP_XOR):
-                comb += Assert(dut.o.o == a ^ b)
+            with m.Case(InternalOp.OP_B):
+                # Extract target address
+                LI = i_fields.LI[0:-1]
+                imm = exts(LI, LI.shape().width, 64-2) * 4
+
+                # Assert that it always branches
+                comb += Assert(dut.o.nia.ok == 1)
+
+                # Check absolute or relative branching
+                with m.If(AA):
+                    comb += Assert(dut.o.nia.data == imm)
+                with m.Else():
+                    comb += Assert(dut.o.nia.data == (cia + imm)[0:64])
+
+                # Make sure linking works
+                with m.If(LK & rec.lk):
+                    comb += Assert(dut.o.lr.data == (cia + 4)[0:64])
+                    comb += Assert(dut.o.lr.ok == 1)
+            with m.Case(InternalOp.OP_BC):
+                # Assert that branches are conditional
+                comb += Assert(dut.o.nia.ok == (cond_ok & ctr_ok))
+
+                # extract target address
+                BD = b_fields.BD[0:-1]
+                imm = exts(BD, BD.shape().width, 64-2) * 4
+
+                # Check absolute or relative branching
+                with m.If(dut.o.nia.ok):
+                    with m.If(AA):
+                        comb += Assert(dut.o.nia.data == imm)
+                    with m.Else():
+                        comb += Assert(dut.o.nia.data == (cia + imm)[0:64])
+                    with m.If(LK & rec.lk):
+                        comb += Assert(dut.o.lr.data == (cia + 4)[0:64])
+                        comb += Assert(dut.o.lr.ok == 1)
+
+                # Check that CTR is decremented
+                with m.If(~BO[2]):
+                    comb += Assert(dut.o.ctr.data == ctr_next)
+            with m.Case(InternalOp.OP_BCREG):
+                # assert that the condition is good
+                comb += Assert(dut.o.nia.ok == (cond_ok & ctr_ok))
+
+                with m.If(dut.o.nia.ok):
+                    # make sure we branch to the spr input
+                    comb += Assert(dut.o.nia.data == spr1)
+
+                    # make sure branch+link works
+                    with m.If(LK & rec.lk):
+                        comb += Assert(dut.o.lr.data == (cia + 4)[0:64])
+                        comb += Assert(dut.o.lr.ok == 1)
+
+                # Check that CTR is decremented
+                with m.If(~BO[2]):
+                    comb += Assert(dut.o.ctr.data == ctr_next)
+
 
         return m
 
 
         return m
 
@@ -80,7 +160,6 @@ class LogicalTestCase(FHDLTestCase):
     def test_formal(self):
         module = Driver()
         self.assertFormal(module, mode="bmc", depth=2)
     def test_formal(self):
         module = Driver()
         self.assertFormal(module, mode="bmc", depth=2)
-        self.assertFormal(module, mode="cover", depth=2)
     def test_ilang(self):
         dut = Driver()
         vl = rtlil.convert(dut, ports=[])
     def test_ilang(self):
         dut = Driver()
         vl = rtlil.convert(dut, ports=[])
index 636e581c4f5d17ab12912ebd504d1fe8f57cb393..d4eecf37f80d3ec1ed93007d11d0aec5d9c828f3 100644 (file)
@@ -64,7 +64,7 @@ class BranchMainStage(PipeModBase):
         br_taken = Signal(reset_less=True)
 
         # Handle absolute or relative branches
         br_taken = Signal(reset_less=True)
 
         # Handle absolute or relative branches
-        with m.If(AA):
+        with m.If(AA | (op.insn_type == InternalOp.OP_BCREG)):
             comb += br_addr.eq(br_imm_addr)
         with m.Else():
             comb += br_addr.eq(br_imm_addr + cia)
             comb += br_addr.eq(br_imm_addr)
         with m.Else():
             comb += br_addr.eq(br_imm_addr + cia)