Allow the formal engine to perform a same-cycle result in the ALU
[soc.git] / src / soc / fu / branch / formal / proof_main_stage.py
# Proof of correctness for the branch unit main stage (BranchMainStage)
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=335
6 * https://libre-soc.org/openpower/isa/branch/
7 """
8
9 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
10 signed, Array, Const)
11 from nmigen.asserts import Assert, AnyConst, Assume, Cover
12 from nmutil.formaltest import FHDLTestCase
13 from nmutil.extend import exts
14 from nmigen.cli import rtlil
15
16 from soc.fu.branch.main_stage import BranchMainStage
17 from soc.fu.branch.pipe_data import BranchPipeSpec
18 from soc.fu.branch.br_input_record import CompBROpSubset
19 from openpower.decoder.power_enums import MicrOp
20 import unittest
21
22
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal-verification harness for BranchMainStage.

    Every input of the stage (the operation record, the CR field, fast1 which
    carries CTR, and fast2) is driven from an unconstrained ``AnyConst``.  The
    outputs (NIA, LR, CTR and the passed-through operation context) are then
    asserted against the Power ISA v3.0B definitions of ``b``, ``bc`` and
    ``bclr``/``bcctr``/``bctar``.
    """

    def __init__(self):
        # no parameters: all stimulus is created in elaborate()
        pass

    def elaborate(self, platform):
        """Build the DUT together with its stimulus and assertions."""
        m = Module()
        comb = m.d.comb

        # Setup random inputs for dut.op: one AnyConst per record field
        rec = CompBROpSubset()
        for p in rec.ports():
            comb += p.eq(AnyConst(p.width))

        pspec = BranchPipeSpec(id_wid=2, parent_pspec=None)
        m.submodules.dut = dut = BranchMainStage(pspec)

        # convenience aliases
        op = dut.i.ctx.op
        cia, fast1, fast2 = op.cia, dut.i.fast1, dut.i.fast2
        ctr = fast1  # CTR is presented to the stage on the fast1 input
        lr_o, nia_o, ctr_o = dut.o.lr, dut.o.nia, dut.o.ctr

        comb += op.eq(rec)

        # Assert that op gets copied from the input to the OUTPUT context.
        # The input side (dut.i.ctx.op) is driven from rec just above, so
        # comparing against it would be trivially true.
        for rec_sig in rec.ports():
            dut_sig = getattr(dut.o.ctx.op, rec_sig.name)
            comb += Assert(dut_sig == rec_sig)

        # Full width CR register. Will have bitfield extracted for
        # feeding to branch unit.
        cr = Signal(32)
        comb += cr.eq(AnyConst(32))
        # cr_arr[i] is 4-bit CR field i, cr_bit_arr[i] is CR bit i,
        # both using MSB0 numbering (index 0 is the top of the register)
        cr_arr = Array([cr[(7-i)*4:(7-i)*4+4] for i in range(8)])
        cr_bit_arr = Array([cr[31-i] for i in range(32)])

        comb += fast2.eq(AnyConst(64))
        comb += ctr.eq(AnyConst(64))

        i_fields = dut.fields.FormI
        b_fields = dut.fields.FormB
        xl_fields = dut.fields.FormXL

        # absolute address mode
        AA = i_fields.AA[0:-1]

        # Handle CR bit selection: the DUT only receives the 4-bit CR
        # field that contains bit BI (field number = BI[2:5])
        BI = b_fields.BI[0:-1]
        bi = Signal(3, reset_less=True)
        comb += bi.eq(BI[2:5])
        comb += dut.i.cr.eq(cr_arr[bi])

        # Branch options and extended opcode, copied onto named wires
        # so that they show up in counterexample traces
        BO = b_fields.BO[0:-1]
        bo = Signal(BO.shape())
        comb += bo.eq(BO)

        XO = xl_fields.XO[0:-1]
        xo = Signal(XO.shape())
        comb += xo.eq(XO)

        # Check CR according to BO (ISA: cond_ok <- BO_0 | (CR_BI == BO_1),
        # where MSB0 BO_0/BO_1 are bo[4]/bo[3] here)
        cond_ok = Signal()
        comb += cond_ok.eq(bo[4] | (cr_bit_arr[BI] == bo[3]))

        # CTR decrement (only when BO[2] is clear)
        ctr_next = Signal.like(ctr)
        with m.If(~bo[2]):
            comb += ctr_next.eq(ctr - 1)
        with m.Else():
            comb += ctr_next.eq(ctr)

        # 32/64 bit view of the *decremented* CTR: the ISA computes
        # ctr_ok after "if not BO_2 then CTR <- CTR - 1", so the zero
        # test must look at ctr_next, not at the incoming ctr.
        ctr_m = Signal.like(ctr)
        with m.If(rec.is_32bit):
            comb += ctr_m.eq(ctr_next[:32])
        with m.Else():
            comb += ctr_m.eq(ctr_next)

        # CTR (32/64 bit) compare with 0 (ISA ctr_ok)
        ctr_ok = Signal()
        comb += ctr_ok.eq(bo[2] | ((ctr_m != 0) ^ bo[1]))

        # NOTE: helpers use m.d.comb (not the local "comb") because an
        # augmented assignment to a closure variable would make it local.
        def assert_link():
            # LR <- CIA + 4 exactly when LK=1, regardless of whether the
            # branch is taken
            m.d.comb += Assert(lr_o.ok == rec.lk)
            with m.If(rec.lk):
                m.d.comb += Assert(lr_o.data == (cia + 4)[0:64])

        def assert_ctr_update():
            # CTR is written with the decremented value iff BO[2] is clear
            with m.If(~bo[2]):
                m.d.comb += Assert(ctr_o.data == ctr_next)
                m.d.comb += Assert(ctr_o.ok == 1)
            with m.Else():
                m.d.comb += Assert(ctr_o.ok == 0)

        with m.Switch(rec.insn_type):

            ###
            # b - v3.0B p37
            ###
            with m.Case(MicrOp.OP_B):
                # Extract target address (LI || 0b00, sign-extended)
                LI = i_fields.LI[0:-1]
                imm = exts(LI, LI.shape().width, 64-2) * 4

                # Assert that it always branches
                comb += Assert(nia_o.ok == 1)

                # Check absolute or relative branching
                with m.If(AA):
                    comb += Assert(nia_o.data == imm)
                with m.Else():
                    comb += Assert(nia_o.data == (cia + imm)[0:64])

                # Make sure linking works
                assert_link()

                # Assert that ctr is not written to
                comb += Assert(ctr_o.ok == 0)

            ####
            # bc - v3.0B p37-38
            ####
            with m.Case(MicrOp.OP_BC):
                # Assert that branches are conditional
                comb += Assert(nia_o.ok == (cond_ok & ctr_ok))

                # extract target address (BD || 0b00, sign-extended)
                BD = b_fields.BD[0:-1]
                imm = exts(BD, BD.shape().width, 64-2) * 4

                # Check absolute or relative branching
                with m.If(nia_o.ok):
                    with m.If(AA):
                        comb += Assert(nia_o.data == imm)
                    with m.Else():
                        comb += Assert(nia_o.data == (cia + imm)[0:64])

                # LR is updated whether or not the branch is taken
                assert_link()

                # Check that CTR is decremented
                assert_ctr_update()

            ##################
            # bctar/bcctr/bclr - v3.0B p38-39
            ##################
            with m.Case(MicrOp.OP_BCREG):
                # assert that the condition is good
                comb += Assert(nia_o.ok == (cond_ok & ctr_ok))

                with m.If(nia_o.ok):
                    # make sure we branch to the spr input, with the low
                    # two bits cleared: xo[9] & ~xo[5] selects bcctr
                    # (target from CTR on fast1); otherwise the target
                    # comes from fast2 (presumably LR or TAR - confirm
                    # against the decoder)
                    with m.If(xo[9] & ~xo[5]):
                        fastext = Cat(Const(0, 2), fast1[2:])
                        comb += Assert(nia_o.data == fastext[0:64])
                    with m.Else():
                        fastext = Cat(Const(0, 2), fast2[2:])
                        comb += Assert(nia_o.data == fastext[0:64])

                # make sure branch+link works
                assert_link()

                # Check that CTR is decremented
                assert_ctr_update()

        return m
199
200
class LogicalTestCase(FHDLTestCase):
    """Tests for the branch main-stage proof harness (Driver).

    Despite its name, this case exercises the branch unit: it runs the
    bounded model check and dumps the harness as RTLIL.
    """

    def test_formal(self):
        """Run a depth-2 bounded model check of the Driver assertions."""
        self.assertFormal(Driver(), mode="bmc", depth=2)

    def test_ilang(self):
        """Convert the Driver harness to RTLIL and save it to main_stage.il."""
        rtlil_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as out_file:
            out_file.write(rtlil_text)
211
212
if __name__ == "__main__":
    # Run the proof and RTLIL-export tests when executed as a script
    unittest.main()