rework branch proof to use br_input_record
[soc.git] / src / soc / fu / branch / formal / proof_main_stage.py
# Proof of correctness for the branch function unit main stage (BranchMainStage)
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=335
6 * https://libre-soc.org/openpower/isa/branch/
7 """
8
9 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
10 signed, Array)
11 from nmigen.asserts import Assert, AnyConst, Assume, Cover
12 from nmutil.formaltest import FHDLTestCase
13 from nmutil.extend import exts
14 from nmigen.cli import rtlil
15
16 from soc.fu.branch.main_stage import BranchMainStage
17 from soc.fu.branch.pipe_data import BranchPipeSpec
18 from soc.fu.branch.br_input_record import CompBROpSubset
19 from soc.decoder.power_enums import MicrOp
20 import unittest
21
22
23 # This defines a module to drive the device under test and assert
24 # properties about its outputs
class Driver(Elaboratable):
    """Formal harness that drives BranchMainStage and checks its outputs.

    Every field of the branch operation record, the full 32-bit CR and the
    two fast-SPR inputs are driven with unconstrained constants (AnyConst).
    The NIA, LR and CTR outputs are then asserted against the branch
    semantics for the OP_B, OP_BC and OP_BCREG micro-ops.  32-bit mode is
    assumed away.
    """

    def __init__(self):
        # No ports: all DUT inputs are free constants created in elaborate().
        pass

    def elaborate(self, platform):
        """Build the DUT plus input constraints and output assertions."""
        m = Module()
        comb = m.d.comb

        # Drive every field of a fresh branch op record with a free constant.
        rec = CompBROpSubset()
        for p in rec.ports():
            comb += p.eq(AnyConst(p.width))

        pspec = BranchPipeSpec(id_wid=2)
        m.submodules.dut = dut = BranchMainStage(pspec)

        # convenience aliases
        op = dut.i.ctx.op
        cia, cr_in, fast1, fast2 = op.cia, dut.i.cr, dut.i.fast1, dut.i.fast2
        ctr = fast1  # this proof treats fast1 as the CTR input
        lr_o, nia_o = dut.o.lr, dut.o.nia

        comb += op.eq(rec)

        # Sanity check: every field of the DUT's input op equals the
        # corresponding field of the driven record.
        for rec_sig in rec.ports():
            dut_sig = getattr(op, rec_sig.name)
            comb += Assert(dut_sig == rec_sig)

        # Full width CR register. Will have bitfield extracted for
        # feeding to branch unit.  CR field 0 is the most significant
        # nibble, and CR bit 0 is the most significant bit (MSB0).
        cr = Signal(32)
        comb += cr.eq(AnyConst(32))
        cr_arr = Array([cr[(7-i)*4:(7-i)*4+4] for i in range(8)])
        cr_bit_arr = Array([cr[31-i] for i in range(32)])

        comb += [fast2.eq(AnyConst(64)),
                 ctr.eq(AnyConst(64)),
                 ]

        i_fields = dut.fields.FormI
        b_fields = dut.fields.FormB
        # AA is bit 30 in both I-form (b) and B-form (bc), so the FormI
        # field is reused for both cases below.
        AA = i_fields.AA[0:-1]

        # Handle CR bit selection: the upper 3 bits of BI pick the CR field
        # that is fed to the DUT.
        BI = b_fields.BI[0:-1]
        bi = Signal(3, reset_less=True)
        comb += bi.eq(BI[2:5])
        comb += cr_in.eq(cr_arr[bi])

        # Handle branch out
        BO = b_fields.BO[0:-1]
        bo = Signal(BO.shape())
        comb += bo.eq(BO)
        cond_ok = Signal()

        # Check CR according to BO: BO_0 (bo[4]) ignores the CR, otherwise
        # CR bit BI must equal BO_1 (bo[3]).
        comb += cond_ok.eq(bo[4] | (cr_bit_arr[BI] == bo[3]))

        # CTR decrement: BO_2 (BO[2]) clear means CTR is decremented.
        ctr_next = Signal.like(ctr)
        with m.If(~BO[2]):
            comb += ctr_next.eq(ctr - 1)
        with m.Else():
            comb += ctr_next.eq(ctr)

        # CTR compare with 0.  The ISA tests the *decremented* CTR
        # ("if ¬BO_2 then CTR <- CTR - 1; ctr_ok <- BO_2 | ((CTR != 0) ^ BO_3)"),
        # so ctr_next is used here; when BO_2 is set ctr_next == ctr and the
        # term is masked anyway.
        ctr_ok = Signal()
        comb += ctr_ok.eq(BO[2] | ((ctr_next != 0) ^ BO[1]))

        # Sorry, not bothering with 32 bit right now
        comb += Assume(~rec.is_32bit)

        with m.Switch(rec.insn_type):

            #### b ####
            with m.Case(MicrOp.OP_B):
                # Extract target address (LI || 0b00, sign-extended)
                LI = i_fields.LI[0:-1]
                imm = exts(LI, LI.shape().width, 64-2) * 4

                # Assert that it always branches
                comb += Assert(nia_o.ok == 1)

                # Check absolute or relative branching
                with m.If(AA):
                    comb += Assert(nia_o.data == imm)
                with m.Else():
                    comb += Assert(nia_o.data == (cia + imm)[0:64])

                # Make sure linking works
                with m.If(rec.lk):
                    comb += Assert(lr_o.data == (cia + 4)[0:64])
                    comb += Assert(lr_o.ok == 1)
                with m.Else():
                    comb += Assert(lr_o.ok == 0)

                # Assert that ctr is not written to
                comb += Assert(dut.o.ctr.ok == 0)

            #### bc ####
            with m.Case(MicrOp.OP_BC):
                # Assert that branches are conditional
                comb += Assert(nia_o.ok == (cond_ok & ctr_ok))

                # extract target address (BD || 0b00, sign-extended)
                BD = b_fields.BD[0:-1]
                imm = exts(BD, BD.shape().width, 64-2) * 4

                # Check absolute or relative branching
                # NOTE(review): per the ISA, LR is written whenever LK=1,
                # taken or not; only the taken case is checked here.
                with m.If(nia_o.ok):
                    with m.If(AA):
                        comb += Assert(nia_o.data == imm)
                    with m.Else():
                        comb += Assert(nia_o.data == (cia + imm)[0:64])
                    comb += Assert(lr_o.ok == rec.lk)
                    with m.If(rec.lk):
                        comb += Assert(lr_o.data == (cia + 4)[0:64])

                # Check that CTR is decremented
                with m.If(~BO[2]):
                    comb += Assert(dut.o.ctr.data == ctr_next)
                    comb += Assert(dut.o.ctr.ok == 1)
                with m.Else():
                    comb += Assert(dut.o.ctr.ok == 0)

            #### bctar/bcctr/bclr ####
            with m.Case(MicrOp.OP_BCREG):
                # assert that the condition is good
                comb += Assert(nia_o.ok == (cond_ok & ctr_ok))

                with m.If(nia_o.ok):
                    # make sure we branch to the spr input
                    # NOTE(review): the ISA clears the low two bits of the
                    # LR/CTR/TAR target, and bclr/bctar do not target CTR;
                    # this assumes the DUT presents the final target on
                    # fast1 unmodified -- confirm against BranchMainStage.
                    comb += Assert(nia_o.data == fast1)

                    # make sure branch+link works
                    comb += Assert(lr_o.ok == rec.lk)
                    with m.If(rec.lk):
                        comb += Assert(lr_o.data == (cia + 4)[0:64])

                # Check that CTR is decremented
                with m.If(~BO[2]):
                    comb += Assert(dut.o.ctr.data == ctr_next)
                    comb += Assert(dut.o.ctr.ok == 1)
                with m.Else():
                    comb += Assert(dut.o.ctr.ok == 0)

        return m
177
178
class LogicalTestCase(FHDLTestCase):
    """Formal check and RTLIL export for the branch main-stage Driver."""

    def test_formal(self):
        """Bounded model check (depth 2) of every assertion in Driver."""
        self.assertFormal(Driver(), mode="bmc", depth=2)

    def test_ilang(self):
        """Convert Driver to RTLIL and dump it to main_stage.il."""
        rtlil_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as outfile:
            outfile.write(rtlil_text)
188
189
# Allow running this proof directly as a script.
if __name__ == "__main__":
    unittest.main()