# Proof of correctness for shift/rotate FU
# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
"""
Links:
* https://bugs.libre-soc.org/show_bug.cgi?id=340

run tests with:
pip install pytest
pip install pytest-xdist
pytest -n auto src/soc/fu/shift_rot/formal/proof_main_stage.py
the -n auto option tells pytest to run the tests in parallel, so the run
takes a few minutes instead of an hour.
"""

import unittest
import enum
from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
                    signed, Const, unsigned)
from nmigen.asserts import Assert, AnyConst, Assume
from nmutil.formaltest import FHDLTestCase
from nmutil.sim_util import do_sim
from nmigen.sim import Delay

from soc.fu.shift_rot.main_stage import ShiftRotMainStage
from soc.fu.shift_rot.pipe_data import ShiftRotPipeSpec
from openpower.decoder.power_enums import MicrOp


@enum.unique
class TstOp(enum.Enum):
    """ops we're testing. the idea is that if we run a separate formal proof
    for each instruction, we end up covering them all, each proof runs much
    faster, and the proofs can be run in parallel."""
    SHL = MicrOp.OP_SHL
    SHR = MicrOp.OP_SHR
    RLC32 = MicrOp.OP_RLC, 32
    RLC64 = MicrOp.OP_RLC, 64
    RLCL = MicrOp.OP_RLCL
    RLCR = MicrOp.OP_RLCR
    EXTSWSLI = MicrOp.OP_EXTSWSLI
    TERNLOG = MicrOp.OP_TERNLOG
    # grev removed -- leaving code for later use in grevlut
    # GREV32 = MicrOp.OP_GREV, 32
    # GREV64 = MicrOp.OP_GREV, 64

    @property
    def op(self):
        if isinstance(self.value, tuple):
            return self.value[0]
        return self.value


def eq_any_const(sig: Signal):
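    """bind sig to a solver-chosen constant (AnyConst), so the formal engine
    explores every possible value of this input."""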
    return sig.eq(AnyConst(sig.shape(), src_loc_at=1))


class Mask(Elaboratable):
    # copied from qemu's mask fn:
    # https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21
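    # generates the PowerISA MASK(start, end): ones from bit `start` through
    # bit `end` in MSB0 numbering, wrapping around when start > end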
    def __init__(self):
        self.start = Signal(6)
        self.end = Signal(6)
        self.out = Signal(64)

    def elaborate(self, platform):
        m = Module()
        max_val = Const(~0, unsigned(64))
        max_bit = 63
        with m.If(self.start == 0):
            m.d.comb += self.out.eq(max_val << (max_bit - self.end))
        with m.Elif(self.end == max_bit):
            m.d.comb += self.out.eq(max_val >> self.start)
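        # general case: (max_val >> start) ^ ((max_val >> end) >> 1) has
        # ones in bit positions start..end (MSB0); when start > end the
        # required mask wraps around, hence the inversion in the Mux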
        with m.Else():
            ret = (max_val >> self.start) ^ ((max_val >> self.end) >> 1)
            m.d.comb += self.out.eq(Mux(self.start > self.end, ~ret, ret))
        return m


class TstMask(unittest.TestCase):
    def test_mask(self):
        dut = Mask()

        def case(start, end, expected):
            with self.subTest(start=start, end=end):
                yield dut.start.eq(start)
                yield dut.end.eq(end)
                yield Delay(1e-6)
                out = yield dut.out
                with self.subTest(out=hex(out), expected=hex(expected)):
                    self.assertEqual(expected, out)

        def process():
            for start in range(64):
                for end in range(64):
                    expected = 0
                    if start > end:
                        for i in range(start, 64):
                            expected |= 1 << (63 - i)
                        for i in range(0, end + 1):
                            expected |= 1 << (63 - i)
                    else:
                        for i in range(start, end + 1):
                            expected |= 1 << (63 - i)
                    yield from case(start, end, expected)
        with do_sim(self, dut, [dut.start, dut.end, dut.out]) as sim:
            sim.add_process(process)
            sim.run()


def rotl64(v, amt):
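    # a rotate-left by amt is done as a single right-shift of the doubled
    # value v||v by (64 - amt), keeping the low 64 bits of the result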
    v |= Const(0, 64)  # convert to value at least 64-bits wide
    amt |= Const(0, 6)  # convert to value at least 6-bits wide
    return (Cat(v[:64], v[:64]) >> (64 - amt[:6]))[:64]


def rotl32(v, amt):
    v |= Const(0, 32)  # convert to value at least 32-bits wide
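    # rotating the doubled word v||v left leaves the 32-bit rotation result
    # replicated in both halves of the 64-bit output, matching the PowerISA
    # ROTL32 definition used by rlwinm/rlwimi/rlwnm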
    return rotl64(Cat(v[:32], v[:32]), amt)


# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    def __init__(self, which):
        assert isinstance(which, TstOp) or which is None
        self.which = which

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None)
        pspec.draft_bitmanip = True
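        # draft_bitmanip is expected to pull the draft bitmanip operations
        # (ternlogi, formerly grev) into the DUT so they can be proven here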
        m.submodules.dut = dut = ShiftRotMainStage(pspec)

        # Set inputs to formal variables
        comb += [
            eq_any_const(dut.i.ctx.op.insn_type),
            eq_any_const(dut.i.ctx.op.fn_unit),
            eq_any_const(dut.i.ctx.op.imm_data.data),
            eq_any_const(dut.i.ctx.op.imm_data.ok),
            eq_any_const(dut.i.ctx.op.rc.rc),
            eq_any_const(dut.i.ctx.op.rc.ok),
            eq_any_const(dut.i.ctx.op.oe.oe),
            eq_any_const(dut.i.ctx.op.oe.ok),
            eq_any_const(dut.i.ctx.op.write_cr0),
            eq_any_const(dut.i.ctx.op.input_carry),
            eq_any_const(dut.i.ctx.op.output_carry),
            eq_any_const(dut.i.ctx.op.input_cr),
            eq_any_const(dut.i.ctx.op.is_32bit),
            eq_any_const(dut.i.ctx.op.is_signed),
            eq_any_const(dut.i.ctx.op.insn),
            eq_any_const(dut.i.xer_ca),
            eq_any_const(dut.i.ra),
            eq_any_const(dut.i.rb),
            eq_any_const(dut.i.rc),
        ]

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

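        # when no specific op is selected, assume insn_type is none of the
        # ops under test and prove that the output stays marked invalid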
        if self.which is None:
            for i in TstOp:
                comb += Assume(dut.i.ctx.op.insn_type != i.op)
            comb += Assert(~dut.o.o.ok)
        else:
            # we're only checking a particular operation:
            comb += Assume(dut.i.ctx.op.insn_type == self.which.op)
            comb += Assert(dut.o.o.ok)

            # dispatch to check fn for each op
            getattr(self, f"_check_{self.which.name.lower()}")(m, dut)

        return m

    def _check_shl(self, m, dut):
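        # as with the rotate checks below, ra is expected to be zero for
        # these instructions, so the proof assumes it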
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:6])[:32])
        with m.Else():
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:7])[:64])
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_shr(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        carry = Signal()
        shift_in_s = Signal(signed(128))
        shift_roundtrip = Signal(signed(128))
        shift_in_u = Signal(128)
        shift_amt = Signal(7)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:6]),
                shift_in_s.eq(dut.i.rs[:32].as_signed()),
                shift_in_u.eq(dut.i.rs[:32]),
            ]
        with m.Else():
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:7]),
                shift_in_s.eq(dut.i.rs.as_signed()),
                shift_in_u.eq(dut.i.rs),
            ]

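        # for arithmetic (signed) shifts, CA is set when the source is
        # negative and any 1 bits were shifted out; shifting the result back
        # left and comparing with the original detects exactly that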
        with m.If(dut.i.ctx.op.is_signed):
            m.d.comb += [
                expected.eq(shift_in_s >> shift_amt),
                shift_roundtrip.eq((shift_in_s >> shift_amt) << shift_amt),
                carry.eq((shift_in_s < 0) & (shift_roundtrip != shift_in_s)),
            ]
        with m.Else():
            m.d.comb += [
                expected.eq(shift_in_u >> shift_amt),
                carry.eq(0),
            ]
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2))

    def _check_rlc32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # rlwimi, rlwinm, and rlwnm

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl32(dut.i.rs[:32], dut.i.rb[:5]))
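        # M-form MB/ME select bits of the low 32-bit word, so add 32 to get
        # the corresponding bit positions in the 64-bit mask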
        m.d.comb += mask.start.eq(dut.fields.FormM.MB[:] + 32)
        m.d.comb += mask.end.eq(dut.fields.FormM.ME[:] + 32)

        # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlc64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldic and rldimi

        # `rb` is always a 6-bit immediate
        m.d.comb += Assume(dut.i.rb[6:] == 0)

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))
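        # the MD-form mb field is split-encoded (the top bit of the 6-bit
        # mask-begin value is held in a separate instruction bit), hence the
        # Cat() re-ordering; the mask ends at bit 63 - sh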
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))
        m.d.comb += mask.end.eq(63 - dut.i.rb[:6])

        # for rldic, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcl(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicl and rldcl

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.end.eq(63)
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcr(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicr and rldcr

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.start.eq(0)
        me = dut.fields.FormMD.me[:]
        m.d.comb += mask.end.eq(Cat(me[1:6], me[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_extswsli(self, m, dut):
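        # extswsli: sign-extend the low 32 bits of RS, then shift left by
        # the 6-bit immediate; CA is never set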
        m.d.comb += Assume(dut.i.ra == 0)
        m.d.comb += Assume(dut.i.rb[6:] == 0)
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)  # all instrs. are 64-bit
        expected = Signal(64)
        m.d.comb += expected.eq((dut.i.rs[0:32].as_signed() << dut.i.rb[:6]))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_ternlog(self, m, dut):
        lut = dut.fields.FormTLI.TLI[:]
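        # each result bit i is the TLI truth-table entry selected by the
        # 3-bit index built from RB[i], RA[i] and RC[i]; looping over all 8
        # index values performs the lut lookup with a non-constant index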
        for i in range(64):
            idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
            for j in range(8):
                with m.If(j == idx):
                    m.d.comb += Assert(dut.o.o.data[i] == lut[j])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    # grev removed -- leaving code for later use in grevlut
    def _check_grev32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # assert zero-extended
        m.d.comb += Assert(dut.o.o.data[32:] == 0)
        i = Signal(5)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:5] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    # grev removed -- leaving code for later use in grevlut
    def _check_grev64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        i = Signal(6)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:6] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)


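# the main stage is combinatorial, so a shallow BMC (depth=2) is enough for
# each per-op Driver; the cover run checks that the Assumes are satisfiable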
class ALUTestCase(FHDLTestCase):
    def run_it(self, which):
        module = Driver(which)
        self.assertFormal(module, mode="bmc", depth=2)
        self.assertFormal(module, mode="cover", depth=2)

    def test_none(self):
        self.run_it(None)

    def test_shl(self):
        self.run_it(TstOp.SHL)

    def test_shr(self):
        self.run_it(TstOp.SHR)

    def test_rlc32(self):
        self.run_it(TstOp.RLC32)

    def test_rlc64(self):
        self.run_it(TstOp.RLC64)

    def test_rlcl(self):
        self.run_it(TstOp.RLCL)

    def test_rlcr(self):
        self.run_it(TstOp.RLCR)

    def test_extswsli(self):
        self.run_it(TstOp.EXTSWSLI)

    def test_ternlog(self):
        self.run_it(TstOp.TERNLOG)

    @unittest.skip("grev removed -- leaving code for later use in grevlut")
    def test_grev32(self):
        self.run_it(TstOp.GREV32)

    @unittest.skip("grev removed -- leaving code for later use in grevlut")
    def test_grev64(self):
        self.run_it(TstOp.GREV64)


# check that all test cases are covered
for i in TstOp:
    assert callable(getattr(ALUTestCase, f"test_{i.name.lower()}"))


if __name__ == '__main__':
    unittest.main()