rename InternalOp to MicrOp
[soc.git] / src / soc / fu / branch / formal / proof_main_stage.py
# Proof of correctness for the branch function unit main stage
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=335
6 * https://libre-soc.org/openpower/isa/branch/
7 """
8
9 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
10 signed, Array)
11 from nmigen.asserts import Assert, AnyConst, Assume, Cover
12 from nmutil.formaltest import FHDLTestCase
13 from nmutil.extend import exts
14 from nmigen.cli import rtlil
15
16 from soc.fu.branch.main_stage import BranchMainStage
17 from soc.fu.alu.pipe_data import ALUPipeSpec
18 from soc.fu.alu.alu_input_record import CompALUOpSubset
19 from soc.decoder.power_enums import MicrOp
20 import unittest
21
22
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal-proof harness around BranchMainStage.

    Feeds the device under test (DUT) with unconstrained constant inputs
    (AnyConst) and asserts, per instruction type, the expected values of
    its next-instruction-address (nia), link-register (lr) and ctr
    outputs.
    """

    def __init__(self):
        """The driver exposes no ports of its own."""
        # inputs and outputs
        pass

    def elaborate(self, platform):
        """Build the proof module.

        platform is accepted for the Elaboratable interface and is not
        used.  Returns the Module holding the DUT (as submodule "dut")
        together with every assumption and assertion.
        """
        m = Module()
        comb = m.d.comb

        rec = CompALUOpSubset()
        # Setup random inputs for dut.op
        for p in rec.ports():
            comb += p.eq(AnyConst(p.width))

        pspec = ALUPipeSpec(id_wid=2)
        m.submodules.dut = dut = BranchMainStage(pspec)

        comb += dut.i.ctx.op.eq(rec)

        # Assert that op gets copied from the input to output
        for rec_sig in rec.ports():
            name = rec_sig.name
            dut_sig = getattr(dut.o.ctx.op, name)
            comb += Assert(dut_sig == rec_sig)

        # Full width CR register. Will have bitfield extracted for
        # feeding to branch unit
        cr = Signal(32)
        comb += cr.eq(AnyConst(32))
        # cr_arr[i] is the 4-bit slice cr[(7-i)*4 : (7-i)*4+4], so index 0
        # is the most significant nibble; cr_bit_arr[i] is cr[31-i], so
        # index 0 is the most significant bit.
        cr_arr = Array([cr[(7-i)*4:(7-i)*4+4] for i in range(8)])
        cr_bit_arr = Array([cr[31-i] for i in range(32)])

        cia, cr_in, fast1, fast2 = dut.i.cia, dut.i.cr, dut.i.fast1, dut.i.fast2
        # ctr is the same signal as fast1
        ctr = fast1
        lr_o, nia_o = dut.o.lr, dut.o.nia

        comb += [fast2.eq(AnyConst(64)),
                 ctr.eq(AnyConst(64)),
                 cia.eq(AnyConst(64))]

        i_fields = dut.fields.FormI
        b_fields = dut.fields.FormB
        AA = i_fields.AA[0:-1]

        # Handle CR bit selection: bits 2..4 of BI pick which 4-bit CR
        # slice is handed to the DUT
        BI = b_fields.BI[0:-1]
        bi = Signal(3, reset_less=True)
        comb += bi.eq(BI[2:5])
        comb += cr_in.eq(cr_arr[bi])

        # Handle branch out
        BO = b_fields.BO[0:-1]
        bo = Signal(BO.shape())
        comb += bo.eq(BO)
        cond_ok = Signal()

        # Check CR according to BO
        comb += cond_ok.eq(bo[4] | (cr_bit_arr[BI] == bo[3]))

        # CTR decrement
        ctr_next = Signal.like(ctr)
        with m.If(~BO[2]):
            comb += ctr_next.eq(ctr - 1)
        with m.Else():
            comb += ctr_next.eq(ctr)

        # CTR compare with 0
        # NOTE(review): the Power ISA pseudo-code evaluates ctr_ok on the
        # CTR value after the decrement, whereas this model compares the
        # pre-decrement value -- confirm this matches the intended DUT
        # behaviour.
        ctr_ok = Signal()
        comb += ctr_ok.eq(BO[2] | ((ctr != 0) ^ BO[1]))

        # Sorry, not bothering with 32 bit right now
        comb += Assume(~rec.is_32bit)

        with m.Switch(rec.insn_type):

            #### b ####
            with m.Case(MicrOp.OP_B):
                # Extract target address
                LI = i_fields.LI[0:-1]
                imm = exts(LI, LI.shape().width, 64-2) * 4

                # Assert that it always branches
                comb += Assert(nia_o.ok == 1)

                # Check absolute or relative branching
                with m.If(AA):
                    comb += Assert(nia_o.data == imm)
                with m.Else():
                    comb += Assert(nia_o.data == (cia + imm)[0:64])

                # Make sure linking works
                with m.If(rec.lk):
                    comb += Assert(lr_o.data == (cia + 4)[0:64])
                    comb += Assert(lr_o.ok == 1)
                with m.Else():
                    comb += Assert(lr_o.ok == 0)

                # Assert that ctr is not written to
                comb += Assert(dut.o.ctr.ok == 0)

            #### bc ####
            with m.Case(MicrOp.OP_BC):
                # Assert that branches are conditional
                comb += Assert(nia_o.ok == (cond_ok & ctr_ok))

                # extract target address
                BD = b_fields.BD[0:-1]
                imm = exts(BD, BD.shape().width, 64-2) * 4

                # Check absolute or relative branching
                with m.If(nia_o.ok):
                    with m.If(AA):
                        comb += Assert(nia_o.data == imm)
                    with m.Else():
                        comb += Assert(nia_o.data == (cia + imm)[0:64])
                    comb += Assert(lr_o.ok == rec.lk)
                    with m.If(rec.lk):
                        comb += Assert(lr_o.data == (cia + 4)[0:64])

                # Check that CTR is decremented
                with m.If(~BO[2]):
                    comb += Assert(dut.o.ctr.data == ctr_next)
                    comb += Assert(dut.o.ctr.ok == 1)
                with m.Else():
                    comb += Assert(dut.o.ctr.ok == 0)

            #### bctar/bcctr/bclr ####
            with m.Case(MicrOp.OP_BCREG):
                # assert that the condition is good
                comb += Assert(nia_o.ok == (cond_ok & ctr_ok))

                with m.If(nia_o.ok):
                    # make sure we branch to the spr input
                    comb += Assert(nia_o.data == fast1)

                    # make sure branch+link works
                    comb += Assert(lr_o.ok == rec.lk)
                    with m.If(rec.lk):
                        comb += Assert(lr_o.data == (cia + 4)[0:64])

                # Check that CTR is decremented
                with m.If(~BO[2]):
                    comb += Assert(dut.o.ctr.data == ctr_next)
                    comb += Assert(dut.o.ctr.ok == 1)
                with m.Else():
                    comb += Assert(dut.o.ctr.ok == 0)

        return m
175
176
class LogicalTestCase(FHDLTestCase):
    """Test case that checks the Driver formally and dumps it as RTLIL."""

    def test_formal(self):
        """Check the Driver's assertions with mode "bmc" at depth 2."""
        driver = Driver()
        self.assertFormal(driver, mode="bmc", depth=2)

    def test_ilang(self):
        """Convert the Driver to RTLIL and write it to main_stage.il."""
        il_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as il_file:
            il_file.write(il_text)
186
187
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()