add formal proof for shift/rot o.ok
[soc.git] src/soc/fu/shift_rot/formal/proof_main_stage.py
# Proof of correctness for shift/rotate FU
# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
"""
Links:
* https://bugs.libre-soc.org/show_bug.cgi?id=340
"""

import unittest
import enum
from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
                    signed, Const, unsigned)
from nmigen.asserts import Assert, AnyConst, Assume
from nmutil.formaltest import FHDLTestCase
from nmutil.sim_util import do_sim
from nmigen.sim import Delay

from soc.fu.shift_rot.main_stage import ShiftRotMainStage
from soc.fu.shift_rot.pipe_data import ShiftRotPipeSpec
from openpower.decoder.power_enums import MicrOp


@enum.unique
class TstOp(enum.Enum):
    """ops we're testing; the idea is that by running a separate formal proof
    for each instruction we still cover them all, each proof runs much
    faster, and the proofs can be run in parallel."""
    SHL = MicrOp.OP_SHL
    SHR = MicrOp.OP_SHR
    RLC32 = MicrOp.OP_RLC, 32
    RLC64 = MicrOp.OP_RLC, 64
    RLCL = MicrOp.OP_RLCL
    RLCR = MicrOp.OP_RLCR
    EXTSWSLI = MicrOp.OP_EXTSWSLI
    TERNLOG = MicrOp.OP_TERNLOG
    GREV32 = MicrOp.OP_GREV, 32
    GREV64 = MicrOp.OP_GREV, 64

    @property
    def op(self):
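        # e.g. TstOp.RLC32.op and TstOp.RLC64.op are both MicrOp.OP_RLC (the
        # second tuple element only distinguishes the 32-bit and 64-bit
        # proofs), while plain members like SHL yield the MicrOp directly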
        if isinstance(self.value, tuple):
            return self.value[0]
        return self.value


def eq_any_const(sig: Signal):
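    # AnyConst yields a value that is fixed over the whole trace but is
    # otherwise unconstrained, so assertions on outputs derived from these
    # inputs are effectively proven for every possible input value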
    return sig.eq(AnyConst(sig.shape(), src_loc_at=1))


class Mask(Elaboratable):
    # copied from qemu's mask fn:
    # https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21
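    # the mask runs from bit `start` to bit `end` inclusive in PowerISA
    # (MSB0) bit numbering, wrapping around when start > end, e.g.:
    #   start=0,  end=63 -> 0xFFFF_FFFF_FFFF_FFFF
    #   start=32, end=63 -> 0x0000_0000_FFFF_FFFF
    #   start=60, end=3  -> 0xF000_0000_0000_000F  (wrap-around case)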
    def __init__(self):
        self.start = Signal(6)
        self.end = Signal(6)
        self.out = Signal(64)

    def elaborate(self, platform):
        m = Module()
        max_val = Const(~0, unsigned(64))
        max_bit = 63
        with m.If(self.start == 0):
            m.d.comb += self.out.eq(max_val << (max_bit - self.end))
        with m.Elif(self.end == max_bit):
            m.d.comb += self.out.eq(max_val >> self.start)
        with m.Else():
            ret = (max_val >> self.start) ^ ((max_val >> self.end) >> 1)
            m.d.comb += self.out.eq(Mux(self.start > self.end, ~ret, ret))
        return m


class TstMask(unittest.TestCase):
    def test_mask(self):
        dut = Mask()

        def case(start, end, expected):
            with self.subTest(start=start, end=end):
                yield dut.start.eq(start)
                yield dut.end.eq(end)
                yield Delay(1e-6)
                out = yield dut.out
                with self.subTest(out=hex(out), expected=hex(expected)):
                    self.assertEqual(expected, out)

        def process():
            for start in range(64):
                for end in range(64):
                    expected = 0
                    if start > end:
                        for i in range(start, 64):
                            expected |= 1 << (63 - i)
                        for i in range(0, end + 1):
                            expected |= 1 << (63 - i)
                    else:
                        for i in range(start, end + 1):
                            expected |= 1 << (63 - i)
                    yield from case(start, end, expected)
        with do_sim(self, dut, [dut.start, dut.end, dut.out]) as sim:
            sim.add_process(process)
            sim.run()


def rotl64(v, amt):
    v |= Const(0, 64)  # convert to value at least 64-bits wide
    amt |= Const(0, 6)  # convert to value at least 6-bits wide
    return (Cat(v[:64], v[:64]) >> (64 - amt[:6]))[:64]


def rotl32(v, amt):
    v |= Const(0, 32)  # convert to value at least 32-bits wide
    return rotl64(Cat(v[:32], v[:32]), amt)
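# quick illustration of the helpers above:
#   rotl64(0x8000_0000_0000_0001, 1) yields 0x0000_0000_0000_0003
#   rotl32 first doubles its 32-bit input into both halves of a 64-bit word
#   (the ISA's ROTL32 is defined on x || x), so
#   rotl32(0x8000_0001, 1) yields 0x0000_0003_0000_0003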


# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    def __init__(self, which):
        assert isinstance(which, TstOp) or which is None
        self.which = which

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None)
        pspec.draft_bitmanip = True
        m.submodules.dut = dut = ShiftRotMainStage(pspec)

        # Set inputs to formal variables
        comb += [
            eq_any_const(dut.i.ctx.op.insn_type),
            eq_any_const(dut.i.ctx.op.fn_unit),
            eq_any_const(dut.i.ctx.op.imm_data.data),
            eq_any_const(dut.i.ctx.op.imm_data.ok),
            eq_any_const(dut.i.ctx.op.rc.rc),
            eq_any_const(dut.i.ctx.op.rc.ok),
            eq_any_const(dut.i.ctx.op.oe.oe),
            eq_any_const(dut.i.ctx.op.oe.ok),
            eq_any_const(dut.i.ctx.op.write_cr0),
            eq_any_const(dut.i.ctx.op.input_carry),
            eq_any_const(dut.i.ctx.op.output_carry),
            eq_any_const(dut.i.ctx.op.input_cr),
            eq_any_const(dut.i.ctx.op.is_32bit),
            eq_any_const(dut.i.ctx.op.is_signed),
            eq_any_const(dut.i.ctx.op.insn),
            eq_any_const(dut.i.xer_ca),
            eq_any_const(dut.i.ra),
            eq_any_const(dut.i.rb),
            eq_any_const(dut.i.rc),
        ]

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        if self.which is None:
            for i in TstOp:
                comb += Assume(dut.i.ctx.op.insn_type != i.op)
            comb += Assert(~dut.o.o.ok)
        else:
            # we're only checking a particular operation:
            comb += Assume(dut.i.ctx.op.insn_type == self.which.op)
            comb += Assert(dut.o.o.ok)

            # dispatch to check fn for each op
            getattr(self, f"_check_{self.which.name.lower()}")(m, dut)

        return m

    def _check_shl(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
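        # 32-bit shifts (slw) honour 6 bits of the shift amount while 64-bit
        # shifts (sld) honour 7, so amounts of 32 (respectively 64) or more
        # shift the value out entirely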
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:6])[:32])
        with m.Else():
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:7])[:64])
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_shr(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        carry = Signal()
        shift_in_s = Signal(signed(128))
        shift_roundtrip = Signal(signed(128))
        shift_in_u = Signal(128)
        shift_amt = Signal(7)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:6]),
                shift_in_s.eq(dut.i.rs[:32].as_signed()),
                shift_in_u.eq(dut.i.rs[:32]),
            ]
        with m.Else():
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:7]),
                shift_in_s.eq(dut.i.rs.as_signed()),
                shift_in_u.eq(dut.i.rs),
            ]

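        # CA is only relevant for the arithmetic (signed) forms: it is set
        # when the source is negative and any 1-bits were shifted out, which
        # the round-trip shift below detects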
        with m.If(dut.i.ctx.op.is_signed):
            m.d.comb += [
                expected.eq(shift_in_s >> shift_amt),
                shift_roundtrip.eq((shift_in_s >> shift_amt) << shift_amt),
                carry.eq((shift_in_s < 0) & (shift_roundtrip != shift_in_s)),
            ]
        with m.Else():
            m.d.comb += [
                expected.eq(shift_in_u >> shift_amt),
                carry.eq(0),
            ]
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2))

    def _check_rlc32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # rlwimi, rlwinm, and rlwnm

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl32(dut.i.rs[:32], dut.i.rb[:5]))
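        # rlwimi/rlwinm/rlwnm use MASK(MB + 32, ME + 32): the 5-bit M-form
        # fields select a mask confined to the low 32 bits of the 64-bit word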
        m.d.comb += mask.start.eq(dut.fields.FormM.MB[:] + 32)
        m.d.comb += mask.end.eq(dut.fields.FormM.ME[:] + 32)

        # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlc64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldic and rldimi

        # `rb` is always a 6-bit immediate
        m.d.comb += Assume(dut.i.rb[6:] == 0)

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))
        mb = dut.fields.FormMD.mb[:]
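        # the 6-bit MD-form mb field is stored in the instruction in a
        # swizzled order, so reassemble the actual mask-begin value here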
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))
        m.d.comb += mask.end.eq(63 - dut.i.rb[:6])

        # for rldic, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcl(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicl and rldcl

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.end.eq(63)
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcr(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicr and rldcr

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.start.eq(0)
        me = dut.fields.FormMD.me[:]
        m.d.comb += mask.end.eq(Cat(me[1:6], me[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_extswsli(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        m.d.comb += Assume(dut.i.rb[6:] == 0)
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)  # all instrs. are 64-bit
        expected = Signal(64)
        m.d.comb += expected.eq((dut.i.rs[0:32].as_signed() << dut.i.rb[:6]))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_ternlog(self, m, dut):
        lut = dut.fields.FormTLI.TLI[:]
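        # each result bit is a 3-input LUT: bit i of RB, RA and RC form an
        # index into the 8-bit TLI immediate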
        for i in range(64):
            idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
            for j in range(8):
                with m.If(j == idx):
                    m.d.comb += Assert(dut.o.o.data[i] == lut[j])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # assert zero-extended
        m.d.comb += Assert(dut.o.o.data[32:] == 0)
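        # grev generalises bit reversal: result[i] == RA[i ^ RB].  Checking
        # this for a single solver-chosen (AnyConst) index i covers every
        # bit position.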
        i = Signal(5)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:5] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        i = Signal(6)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:6] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)


class ALUTestCase(FHDLTestCase):
    def run_it(self, which):
        module = Driver(which)
        self.assertFormal(module, mode="bmc", depth=2)
        self.assertFormal(module, mode="cover", depth=2)
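        # the main stage is combinational, so a shallow BMC depth suffices;
        # the "cover" run guards against the Assumes being unsatisfiable
        # (which would make the bmc proof vacuous)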

    def test_none(self):
        self.run_it(None)

    def test_shl(self):
        self.run_it(TstOp.SHL)

    def test_shr(self):
        self.run_it(TstOp.SHR)

    def test_rlc32(self):
        self.run_it(TstOp.RLC32)

    def test_rlc64(self):
        self.run_it(TstOp.RLC64)

    def test_rlcl(self):
        self.run_it(TstOp.RLCL)

    def test_rlcr(self):
        self.run_it(TstOp.RLCR)

    def test_extswsli(self):
        self.run_it(TstOp.EXTSWSLI)

    def test_ternlog(self):
        self.run_it(TstOp.TERNLOG)

    def test_grev32(self):
        self.run_it(TstOp.GREV32)

    def test_grev64(self):
        self.run_it(TstOp.GREV64)


# check that all test cases are covered
for i in TstOp:
    assert callable(getattr(ALUTestCase, f"test_{i.name.lower()}"))
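# the proofs can also be run individually from the command line, e.g.
# (paths assume the soc.git checkout root as the working directory):
#   python3 src/soc/fu/shift_rot/formal/proof_main_stage.py ALUTestCase.test_shl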


if __name__ == '__main__':
    unittest.main()