Address code review comments
[soc.git] / src / soc / fu / trap / formal / proof_main_stage.py
1 # Proof of correctness for trap pipeline, main stage
2
3
4 """
5 Links:
6 * https://bugs.libre-soc.org/show_bug.cgi?id=421
7 * https://libre-soc.org/openpower/isa/fixedtrap/
8 * https://libre-soc.org/openpower/isa/sprset/
9 * https://libre-soc.org/openpower/isa/system/
10 """
11
12
13 import unittest
14
15 from nmigen import Cat, Const, Elaboratable, Module, Signal, signed
16 from nmigen.asserts import Assert, AnyConst
17 from nmigen.cli import rtlil
18
19 from nmutil.extend import exts
20 from nmutil.formaltest import FHDLTestCase
21
22 from soc.consts import MSR, MSRb, PI, TT
23
24 from soc.decoder.power_enums import MicrOp
25
26 from soc.fu.trap.main_stage import TrapMainStage
27 from soc.fu.trap.pipe_data import TrapPipeSpec
28 from soc.fu.trap.trap_input_record import CompTrapOpSubset
29
30
def field(r, start, end=None, width=64):
    """Answer a subfield of the signal r ("register") using IBM (MSB0)
    bit numbering.

    In IBM numbering bit 0 is the most significant bit of a ``width``-bit
    register and bit ``width - 1`` the least significant.  With only
    ``start`` given, the single bit ``start`` is returned.  Otherwise the
    range ``start..end`` (inclusive on both ends, ``start < end``) is
    returned as a slice.

    :param r: anything supporting LSB0 indexing and slicing, e.g. an
              nmigen Value.
    :param start: first (most significant) IBM bit number.
    :param end: last (least significant) IBM bit number, or None for a
                single bit.
    :param width: register width the IBM numbering refers to (default 64).
    :raises ValueError: if a bit number lies outside ``0..width-1``, or if
                        ``start >= end``.
    """
    # Reject out-of-range bit numbers explicitly: otherwise a negative
    # LSB0 index would silently wrap around to the other end of r.
    for name, bit in (("start", start), ("end", end)):
        if bit is not None and not 0 <= bit < width:
            raise ValueError(
                "{} ({}) must be in range 0..{}".format(name, bit, width - 1)
            )
    msb = width - 1
    if end is None:
        return r[msb - start]
    if start >= end:
        raise ValueError(
            "start ({}) must be less than end ({})".format(start, end)
        )
    # IBM start/end map to LSB0 high/low; the slice stop is exclusive.
    return r[msb - end:msb - start + 1]
45
46
class Driver(Elaboratable):
    """Formal-verification harness for TrapMainStage.

    ``elaborate`` instantiates the trap pipeline's main stage and feeds it
    an operation record that this harness never drives.  It then adds
    ``Assert`` statements that check the stage's MSR, SRR0, SRR1 and NIA
    outputs for the OP_TRAP, OP_SC and OP_RFID micro-ops against the
    Power ISA v3.0B pseudo-code.
    """

    @staticmethod
    def _expected_interrupt_msr(m, msr_in, width, name):
        """Return a ``width``-bit Signal holding the MSR expected after an
        interrupt is taken.  Shared by the OP_TRAP and OP_SC cases, which
        expect identical MSR updates.

        The value is ``msr_in`` with the bits below overridden.  Unless
        otherwise documented, these exceptions to the MSR bits are
        documented in Power ISA V3.0B, page 1063 or 1064.  We are not
        supporting hypervisor or transactional semantics, so we skip
        enforcing those fields' properties.

        The statements are added to ``m.d.comb`` under whatever
        If/Case context ``m`` is currently in.
        """
        comb = m.d.comb
        msr = Signal(width, name=name)
        comb += msr.eq(msr_in)
        comb += field(msr, MSRb.IR).eq(0)
        comb += field(msr, MSRb.DR).eq(0)
        comb += field(msr, MSRb.FE0).eq(0)
        comb += field(msr, MSRb.FE1).eq(0)
        comb += field(msr, MSRb.EE).eq(0)
        comb += field(msr, MSRb.RI).eq(0)
        comb += field(msr, MSRb.SF).eq(1)
        comb += field(msr, MSRb.TM).eq(0)
        comb += field(msr, MSRb.VEC).eq(0)
        comb += field(msr, MSRb.VSX).eq(0)
        comb += field(msr, MSRb.PR).eq(0)
        comb += field(msr, MSRb.FP).eq(0)
        comb += field(msr, MSRb.PMM).eq(0)
        comb += field(msr, MSRb.TEs, MSRb.TEe).eq(0)
        comb += field(msr, MSRb.UND).eq(0)
        comb += field(msr, MSRb.LE).eq(1)
        return msr

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        rec = CompTrapOpSubset()
        pspec = TrapPipeSpec(id_wid=2)

        m.submodules.dut = dut = TrapMainStage(pspec)

        # frequently used aliases
        op = dut.i.ctx.op
        msr_o, msr_i = dut.o.msr, op.msr
        srr0_o, srr0_i = dut.o.srr0, dut.i.srr0
        srr1_o, srr1_i = dut.o.srr1, dut.i.srr1
        nia_o = dut.o.nia

        # rec is not driven anywhere in this harness; presumably the formal
        # tools therefore treat the whole operation as a free input.
        comb += op.eq(rec)

        d_fields = dut.fields.FormD

        # start of properties
        with m.Switch(op.insn_type):

            ###############
            # TDI/TWI/TD/TW. v3.0B p90-91
            ###############
            with m.Case(MicrOp.OP_TRAP):
                # TO: trap-condition mask from the D-form fields, ANDed
                # with trapbits below.
                to = Signal(len(d_fields.TO))
                comb += to.eq(d_fields.TO[0:-1])

                a_i = Signal(64)
                b_i = Signal(64)
                comb += a_i.eq(dut.i.a)
                comb += b_i.eq(dut.i.b)

                a_s = Signal(signed(64), reset_less=True)
                b_s = Signal(signed(64), reset_less=True)
                a = Signal(64, reset_less=True)
                b = Signal(64, reset_less=True)

                # 32-bit forms compare only the low words: sign-extended
                # for the signed compares, zero-extended for the unsigned.
                with m.If(op.is_32bit):
                    comb += a_s.eq(exts(a_i, 32, 64))
                    comb += b_s.eq(exts(b_i, 32, 64))
                    comb += a.eq(a_i[0:32])
                    comb += b.eq(b_i[0:32])
                with m.Else():
                    comb += a_s.eq(a_i)
                    comb += b_s.eq(b_i)
                    comb += a.eq(a_i)
                    comb += b.eq(b_i)

                lt_s = Signal(reset_less=True)
                gt_s = Signal(reset_less=True)
                lt_u = Signal(reset_less=True)
                gt_u = Signal(reset_less=True)
                equal = Signal(reset_less=True)

                comb += lt_s.eq(a_s < b_s)
                comb += gt_s.eq(a_s > b_s)
                comb += lt_u.eq(a < b)
                comb += gt_u.eq(a > b)
                comb += equal.eq(a == b)

                # LSB first; assumes `to` holds TO bit 0 (signed <) in its
                # MSB, as in the ISA's TO0..TO4 ordering -- confirm.
                trapbits = Signal(5, reset_less=True)
                comb += trapbits.eq(Cat(gt_u, lt_u, equal, gt_s, lt_s))

                # trap if any forced trap type is set, or any comparison
                # enabled by TO holds
                take_trap = Signal()
                traptype = op.traptype
                comb += take_trap.eq(traptype.any() | (trapbits & to).any())

                with m.If(take_trap):
                    expected_msr = self._expected_interrupt_msr(
                        m, op.msr, len(msr_o.data), "expected_msr")

                    expected_srr1 = Signal(len(srr1_o.data))
                    comb += expected_srr1.eq(op.msr)

                    # SRR1 33:36 <- 0
                    comb += field(expected_srr1, 33, 36).eq(0)
                    # SRR1 status bits (LSB0 indices from PI) derived
                    # from traptype
                    comb += expected_srr1[PI.TRAP].eq(traptype == 0)
                    comb += expected_srr1[PI.PRIV].eq(traptype[1])
                    comb += expected_srr1[PI.FP].eq(traptype[0])
                    comb += expected_srr1[PI.ADR].eq(traptype[3])
                    comb += expected_srr1[PI.ILLEG].eq(traptype[4])
                    comb += expected_srr1[PI.TM_BAD_THING].eq(0)

                    comb += [
                        Assert(msr_o.ok),
                        Assert(msr_o.data == expected_msr),
                        Assert(srr0_o.ok),
                        Assert(srr0_o.data == op.cia),
                        Assert(srr1_o.ok),
                        Assert(srr1_o.data == expected_srr1),
                        Assert(nia_o.ok),
                        Assert(nia_o.data == op.trapaddr << 4),
                    ]

            #################
            # SC. v3.0B p952
            #################
            with m.Case(MicrOp.OP_SC):
                expected_msr = self._expected_interrupt_msr(
                    m, op.msr, len(msr_o.data), "expected_msr")

                comb += [
                    Assert(srr0_o.ok),
                    Assert(srr1_o.ok),
                    Assert(msr_o.ok),
                    Assert(nia_o.ok),

                    # return address: the instruction after the sc
                    Assert(srr0_o.data == (op.cia + 4)[0:64]),
                    # SRR1 33:36 42:47 <- 0
                    Assert(field(srr1_o.data, 33, 36) == 0),
                    Assert(field(srr1_o.data, 42, 47) == 0),
                    # SRR1 0:32 37:41 48:63 <- MSR 0:32 37:41 48:63
                    Assert(field(srr1_o.data, 0, 32) == field(msr_i, 0, 32)),
                    Assert(field(srr1_o.data, 37, 41) ==
                           field(msr_i, 37, 41)),
                    Assert(field(srr1_o.data, 48, 63) ==
                           field(msr_i, 48, 63)),

                    Assert(msr_o.data == expected_msr),

                    # system call interrupt vector
                    Assert(nia_o.data == 0xC00),
                ]

            ###################
            # RFID. v3.0B p955
            ###################
            with m.Case(MicrOp.OP_RFID):
                comb += [
                    Assert(msr_o.ok),
                    Assert(nia_o.ok),
                ]

                # if (MSR[29:31] != 0b010) | (SRR1[29:31] != 0b000) then
                #     MSR[29:31] <- SRR1[29:31]
                with m.If((field(msr_i, 29, 31) != 0b010) |
                          (field(srr1_i, 29, 31) != 0b000)):
                    comb += Assert(field(msr_o.data, 29, 31) ==
                                   field(srr1_i, 29, 31))
                with m.Else():
                    comb += Assert(field(msr_o.data, 29, 31) ==
                                   field(msr_i, 29, 31))

                # check EE (48) IR (58), DR (59): PR (49) will over-ride
                for bit in [48, 58, 59]:
                    comb += Assert(
                        field(msr_o.data, bit) ==
                        (field(srr1_i, bit) | field(srr1_i, 49))
                    )

                # remaining bits: straight copy from SRR1, as listed in
                # the v3.0B rfid pseudo-code.
                # NOTE(review): MSR bits 3 and 51 are not checked here,
                # presumably because hypervisor state is not supported.
                comb += [
                    Assert(field(msr_o.data, 0, 2) == field(srr1_i, 0, 2)),
                    Assert(field(msr_o.data, 4, 28) == field(srr1_i, 4, 28)),
                    Assert(field(msr_o.data, 32) == field(srr1_i, 32)),
                    Assert(field(msr_o.data, 37, 41) ==
                           field(srr1_i, 37, 41)),
                    Assert(field(msr_o.data, 49, 50) ==
                           field(srr1_i, 49, 50)),
                    Assert(field(msr_o.data, 52, 57) ==
                           field(srr1_i, 52, 57)),
                    Assert(field(msr_o.data, 60, 63) ==
                           field(srr1_i, 60, 63)),
                ]

                # check NIA against SRR0. 2 LSBs are set to zero (word-align)
                comb += Assert(nia_o.data == Cat(Const(0, 2), srr0_i[2:]))

        comb += dut.i.ctx.matches(dut.o.ctx)

        return m
251
252
class TrapMainStageTestCase(FHDLTestCase):
    """Runs the trap main-stage proof and exports its RTLIL."""

    def test_formal(self):
        """Check the Driver properties: bounded model check, then cover."""
        for mode in ("bmc", "cover"):
            self.assertFormal(Driver(), mode=mode, depth=10)

    def test_ilang(self):
        """Write the Driver's RTLIL to trap_main_stage.il."""
        il_text = rtlil.convert(Driver(), ports=[])
        with open("trap_main_stage.il", "w") as il_file:
            il_file.write(il_text)
262
263
if __name__ == "__main__":
    # run this module's unittest test cases when executed as a script
    unittest.main()
266