From 897123c76b338d42b4d12f923e090e16f3c73f07 Mon Sep 17 00:00:00 2001 From: Michael Nolan Date: Fri, 22 May 2020 14:20:13 -0400 Subject: [PATCH] Add formal proof for branch unit, fix bug with bcreg --- src/soc/fu/branch/formal/proof_main_stage.py | 141 +++++++++++++++---- src/soc/fu/branch/main_stage.py | 2 +- 2 files changed, 111 insertions(+), 32 deletions(-) diff --git a/src/soc/fu/branch/formal/proof_main_stage.py b/src/soc/fu/branch/formal/proof_main_stage.py index 804643df..0878efed 100644 --- a/src/soc/fu/branch/formal/proof_main_stage.py +++ b/src/soc/fu/branch/formal/proof_main_stage.py @@ -2,12 +2,13 @@ # Copyright (C) 2020 Michael Nolan from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl, - signed) + signed, Array) from nmigen.asserts import Assert, AnyConst, Assume, Cover from nmigen.test.utils import FHDLTestCase +from nmutil.extend import exts from nmigen.cli import rtlil -from soc.fu.logical.main_stage import LogicalMainStage +from soc.fu.branch.main_stage import BranchMainStage from soc.fu.alu.pipe_data import ALUPipeSpec from soc.fu.alu.alu_input_record import CompALUOpSubset from soc.decoder.power_enums import InternalOp @@ -33,22 +34,8 @@ class Driver(Elaboratable): recwidth += width comb += p.eq(AnyConst(width)) - pspec = ALUPipeSpec(id_wid=2, op_wid=recwidth) - m.submodules.dut = dut = LogicalMainStage(pspec) - - # convenience variables - a = dut.i.a - b = dut.i.b - carry_in = dut.i.carry_in - so_in = dut.i.so - carry_out = dut.o.carry_out - o = dut.o.o - - # setup random inputs - comb += [a.eq(AnyConst(64)), - b.eq(AnyConst(64)), - carry_in.eq(AnyConst(1)), - so_in.eq(AnyConst(1))] + pspec = ALUPipeSpec(id_wid=2) + m.submodules.dut = dut = BranchMainStage(pspec) comb += dut.i.ctx.op.eq(rec) @@ -58,20 +45,113 @@ class Driver(Elaboratable): dut_sig = getattr(dut.o.ctx.op, name) comb += Assert(dut_sig == rec_sig) - # signed and signed/32 versions of input a - a_signed = Signal(signed(64)) - a_signed_32 = Signal(signed(32)) - comb += a_signed.eq(a) - comb += a_signed_32.eq(a[0:32]) + # Full width CR register. Will have bitfield extracted for + # feeding to branch unit + cr = Signal(32) + comb += cr.eq(AnyConst(32)) + cr_arr = Array([cr[(7-i)*4:(7-i)*4+4] for i in range(8)]) + cr_bit_arr = Array([cr[31-i] for i in range(32)]) + + spr1 = dut.i.spr1 + ctr = dut.i.spr2 + cr_in = dut.i.cr + cia = dut.i.cia + + comb += [spr1.eq(AnyConst(64)), + ctr.eq(AnyConst(64)), + cia.eq(AnyConst(64))] + + i_fields = dut.fields.FormI + b_fields = dut.fields.FormB + AA = i_fields.AA[0:-1] + LK = i_fields.LK[0:-1] + + # Handle CR bit selection + BI = b_fields.BI[0:-1] + bi = Signal(3, reset_less=True) + comb += bi.eq(BI[2:5]) + comb += dut.i.cr.eq(cr_arr[bi]) + + # Handle branch out + BO = b_fields.BO[0:-1] + bo = Signal(BO.shape()) + comb += bo.eq(BO) + cond_ok = Signal() + + # Check CR according to BO + comb += cond_ok.eq(bo[4] | (cr_bit_arr[BI] == bo[3])) + + # CTR decrement + ctr_next = Signal.like(ctr) + with m.If(~BO[2]): + comb += ctr_next.eq(ctr - 1) + with m.Else(): + comb += ctr_next.eq(ctr) + + # CTR combpare with 0 + ctr_ok = Signal() + comb += ctr_ok.eq(BO[2] | ((ctr != 0) ^ BO[1])) + + # Sorry, not bothering with 32 bit right now + comb += Assume(~rec.is_32bit) - # main assertion of arithmetic operations with m.Switch(rec.insn_type): - with m.Case(InternalOp.OP_AND): - comb += Assert(dut.o.o == a & b) - with m.Case(InternalOp.OP_OR): - comb += Assert(dut.o.o == a | b) - with m.Case(InternalOp.OP_XOR): - comb += Assert(dut.o.o == a ^ b) + with m.Case(InternalOp.OP_B): + # Extract target address + LI = i_fields.LI[0:-1] + imm = exts(LI, LI.shape().width, 64-2) * 4 + + # Assert that it always branches + comb += Assert(dut.o.nia.ok == 1) + + # Check absolute or relative branching + with m.If(AA): + comb += Assert(dut.o.nia.data == imm) + with m.Else(): + comb += Assert(dut.o.nia.data == (cia + imm)[0:64]) + + # Make sure linking works + with m.If(LK & rec.lk): + comb += Assert(dut.o.lr.data == (cia + 4)[0:64]) + comb += Assert(dut.o.lr.ok == 1) + with m.Case(InternalOp.OP_BC): + # Assert that branches are conditional + comb += Assert(dut.o.nia.ok == (cond_ok & ctr_ok)) + + # extract target address + BD = b_fields.BD[0:-1] + imm = exts(BD, BD.shape().width, 64-2) * 4 + + # Check absolute or relative branching + with m.If(dut.o.nia.ok): + with m.If(AA): + comb += Assert(dut.o.nia.data == imm) + with m.Else(): + comb += Assert(dut.o.nia.data == (cia + imm)[0:64]) + with m.If(LK & rec.lk): + comb += Assert(dut.o.lr.data == (cia + 4)[0:64]) + comb += Assert(dut.o.lr.ok == 1) + + # Check that CTR is decremented + with m.If(~BO[2]): + comb += Assert(dut.o.ctr.data == ctr_next) + with m.Case(InternalOp.OP_BCREG): + # assert that the condition is good + comb += Assert(dut.o.nia.ok == (cond_ok & ctr_ok)) + + with m.If(dut.o.nia.ok): + # make sure we branch to the spr input + comb += Assert(dut.o.nia.data == spr1) + + # make sure branch+link works + with m.If(LK & rec.lk): + comb += Assert(dut.o.lr.data == (cia + 4)[0:64]) + comb += Assert(dut.o.lr.ok == 1) + + # Check that CTR is decremented + with m.If(~BO[2]): + comb += Assert(dut.o.ctr.data == ctr_next) + return m @@ -80,7 +160,6 @@ class LogicalTestCase(FHDLTestCase): def test_formal(self): module = Driver() self.assertFormal(module, mode="bmc", depth=2) - self.assertFormal(module, mode="cover", depth=2) def test_ilang(self): dut = Driver() vl = rtlil.convert(dut, ports=[]) diff --git a/src/soc/fu/branch/main_stage.py b/src/soc/fu/branch/main_stage.py index 636e581c..d4eecf37 100644 --- a/src/soc/fu/branch/main_stage.py +++ b/src/soc/fu/branch/main_stage.py @@ -64,7 +64,7 @@ class BranchMainStage(PipeModBase): br_taken = Signal(reset_less=True) # Handle absolute or relative branches - with m.If(AA): + with m.If(AA | (op.insn_type == InternalOp.OP_BCREG)): comb += br_addr.eq(br_imm_addr) with m.Else(): comb += br_addr.eq(br_imm_addr + cia) -- 2.30.2