Add formal proof for branch unit, fix bug with bcreg
[soc.git] / src / soc / fu / branch / formal / proof_main_stage.py
# Proof of correctness for the branch unit main stage (BranchMainStage)
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3
4 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
5 signed, Array)
6 from nmigen.asserts import Assert, AnyConst, Assume, Cover
7 from nmigen.test.utils import FHDLTestCase
8 from nmutil.extend import exts
9 from nmigen.cli import rtlil
10
11 from soc.fu.branch.main_stage import BranchMainStage
12 from soc.fu.alu.pipe_data import ALUPipeSpec
13 from soc.fu.alu.alu_input_record import CompALUOpSubset
14 from soc.decoder.power_enums import InternalOp
15 import unittest
16
17
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Drive BranchMainStage with unconstrained inputs and assert its outputs.

    The operation record, the full 32-bit CR, SPR1, CTR (SPR2) and CIA are
    all driven from ``AnyConst`` values, so the formal engine explores every
    combination.  The assertions check that the op record is passed through
    unchanged, and check NIA / LR / CTR for OP_B, OP_BC and OP_BCREG.
    """

    def __init__(self):
        # No configuration: every DUT input is created in elaborate().
        pass

    def elaborate(self, platform):
        """Build the DUT harness and its formal assertions.

        :param platform: unused (required by the Elaboratable protocol).
        :returns: the Module containing the DUT and the assertions.
        """
        m = Module()
        comb = m.d.comb

        rec = CompALUOpSubset()
        # Setup random inputs for dut.op
        for p in rec.ports():
            comb += p.eq(AnyConst(p.width))

        pspec = ALUPipeSpec(id_wid=2)
        m.submodules.dut = dut = BranchMainStage(pspec)

        comb += dut.i.ctx.op.eq(rec)

        # Assert that op gets copied from the input to output
        for rec_sig in rec.ports():
            dut_sig = getattr(dut.o.ctx.op, rec_sig.name)
            comb += Assert(dut_sig == rec_sig)

        # Full width CR register. Will have bitfield extracted for
        # feeding to branch unit.
        cr = Signal(32)
        comb += cr.eq(AnyConst(32))
        # cr_arr[n]: 4-bit CR field n (field 0 is the top nibble, cr[28:32])
        cr_arr = Array([cr[(7-i)*4:(7-i)*4+4] for i in range(8)])
        # cr_bit_arr[n]: CR bit n, numbered from the MSB (cr[31] is bit 0)
        cr_bit_arr = Array([cr[31-i] for i in range(32)])

        spr1 = dut.i.spr1
        ctr = dut.i.spr2
        cia = dut.i.cia

        comb += [spr1.eq(AnyConst(64)),
                 ctr.eq(AnyConst(64)),
                 cia.eq(AnyConst(64))]

        i_fields = dut.fields.FormI
        b_fields = dut.fields.FormB
        AA = i_fields.AA[0:-1]
        LK = i_fields.LK[0:-1]
        # Branch-and-link requested by both the instruction and the op record
        link = LK & rec.lk

        # Handle CR bit selection: the DUT only receives the 4-bit CR field
        # selected by the top three bits of BI.
        BI = b_fields.BI[0:-1]
        bi = Signal(3, reset_less=True)
        comb += bi.eq(BI[2:5])
        comb += dut.i.cr.eq(cr_arr[bi])

        # Handle branch out
        BO = b_fields.BO[0:-1]
        bo = Signal(BO.shape())
        comb += bo.eq(BO)

        # Check CR according to BO: bo[4] set ignores the CR, otherwise
        # the selected CR bit must equal bo[3].
        cond_ok = Signal()
        comb += cond_ok.eq(bo[4] | (cr_bit_arr[BI] == bo[3]))

        # CTR decrement (skipped when bo[2] is set)
        ctr_next = Signal.like(ctr)
        with m.If(~bo[2]):
            comb += ctr_next.eq(ctr - 1)
        with m.Else():
            comb += ctr_next.eq(ctr)

        # CTR compare with 0. The ISA tests the CTR value *after* the
        # decrement, so compare ctr_next rather than the incoming ctr.
        ctr_ok = Signal()
        comb += ctr_ok.eq(bo[2] | ((ctr_next != 0) ^ bo[1]))

        # Only the 64-bit mode is checked for now
        comb += Assume(~rec.is_32bit)

        with m.Switch(rec.insn_type):
            with m.Case(InternalOp.OP_B):
                # Extract target address
                LI = i_fields.LI[0:-1]
                imm = exts(LI, LI.shape().width, 64-2) * 4

                # Assert that it always branches
                comb += Assert(dut.o.nia.ok == 1)
                self._assert_target(m, dut, AA, cia, imm)
                self._assert_link(m, dut, link, cia)

            with m.Case(InternalOp.OP_BC):
                # Assert that branches are conditional
                comb += Assert(dut.o.nia.ok == (cond_ok & ctr_ok))

                # Extract target address
                BD = b_fields.BD[0:-1]
                imm = exts(BD, BD.shape().width, 64-2) * 4

                with m.If(dut.o.nia.ok):
                    self._assert_target(m, dut, AA, cia, imm)
                # LR is written on LK regardless of whether the branch
                # is taken.
                self._assert_link(m, dut, link, cia)
                self._assert_ctr(m, dut, bo, ctr_next)

            with m.Case(InternalOp.OP_BCREG):
                # Assert that the condition is good
                comb += Assert(dut.o.nia.ok == (cond_ok & ctr_ok))

                with m.If(dut.o.nia.ok):
                    # Make sure we branch to the spr input.
                    # NOTE(review): the ISA forms this target as
                    # SPR[0:61] || 0b00; confirm whether the low two bits
                    # should be masked here.
                    comb += Assert(dut.o.nia.data == spr1)
                self._assert_link(m, dut, link, cia)
                self._assert_ctr(m, dut, bo, ctr_next)

        return m

    @staticmethod
    def _assert_target(m, dut, aa, cia, imm):
        """Assert NIA is ``imm`` when ``aa`` is set, else ``cia + imm`` mod 2**64."""
        with m.If(aa):
            m.d.comb += Assert(dut.o.nia.data == imm)
        with m.Else():
            m.d.comb += Assert(dut.o.nia.data == (cia + imm)[0:64])

    @staticmethod
    def _assert_link(m, dut, link, cia):
        """When ``link`` is set, assert LR is written with ``cia + 4`` mod 2**64."""
        with m.If(link):
            m.d.comb += [Assert(dut.o.lr.data == (cia + 4)[0:64]),
                         Assert(dut.o.lr.ok == 1)]

    @staticmethod
    def _assert_ctr(m, dut, bo, ctr_next):
        """When ``bo[2]`` is clear, assert the CTR output is the decremented CTR."""
        with m.If(~bo[2]):
            m.d.comb += Assert(dut.o.ctr.data == ctr_next)
157
158
class LogicalTestCase(FHDLTestCase):
    """Formal check and RTLIL export for the branch main-stage Driver.

    NOTE(review): the class name looks copied from another unit's proof;
    it is kept unchanged so existing test selection keeps working.
    """

    def test_formal(self):
        """Run a depth-2 bounded model check over the Driver's assertions."""
        self.assertFormal(Driver(), mode="bmc", depth=2)

    def test_ilang(self):
        """Convert the Driver to RTLIL and write it to main_stage.il."""
        il_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as il_file:
            il_file.write(il_text)
168
169
# Running this file directly executes the test cases defined above.
if __name__ == "__main__":
    unittest.main()