be0c4b169fd94a79795c5c7781ba8e66be8b7566
1 # Proof of correctness for shift/rotate FU
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
5 * https://bugs.libre-soc.org/show_bug.cgi?id=340
9 from shutil
import which
10 from nmigen
import (Module
, Signal
, Elaboratable
, Mux
, Cat
, Repl
,
11 signed
, Array
, Const
, Value
)
12 from nmigen
.asserts
import Assert
, AnyConst
, Assume
, Cover
13 from nmutil
.formaltest
import FHDLTestCase
14 from nmigen
.cli
import rtlil
16 from soc
.fu
.shift_rot
.main_stage
import ShiftRotMainStage
17 from soc
.fu
.shift_rot
.rotator
import right_mask
, left_mask
18 from soc
.fu
.shift_rot
.pipe_data
import ShiftRotPipeSpec
19 from soc
.fu
.shift_rot
.sr_input_record
import CompSROpSubset
20 from openpower
.decoder
.power_enums
import MicrOp
21 from openpower
.consts
import field
24 from nmutil
.extend
import exts
28 class TstOp(enum
.Enum
):
29 """ops we're testing, the idea is if we run a separate formal proof for
30 each instruction, we end up covering them all and each runs much faster,
31 also the formal proofs can be run in parallel."""
37 EXTSWSLI
= MicrOp
.OP_EXTSWSLI
38 TERNLOG
= MicrOp
.OP_TERNLOG
39 GREV32
= MicrOp
.OP_GREV
, 32
40 GREV64
= MicrOp
.OP_GREV
, 64
44 if isinstance(self
.value
, tuple):
def eq_any_const(sig: Signal):
    """Drive *sig* with a formal `AnyConst` of the same shape.

    Returns the nmigen assignment statement (caller adds it to a
    combinatorial domain).  `src_loc_at=1` attributes the constant's
    source location to the caller rather than to this helper.
    """
    rand_const = AnyConst(sig.shape(), src_loc_at=1)
    return sig.eq(rand_const)
53 class Mask(Elaboratable
):
54 # copied from qemu's mask fn:
55 # https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21
57 self
.start
= Signal(6)
61 def elaborate(self
, platform
):
63 max_val
= Const(~
0, 64)
65 with m
.If(self
.start
== 0):
66 m
.d
.comb
+= self
.out
.eq(max_val
<< (max_bit
- self
.end
))
67 with m
.Elif(self
.end
== max_bit
):
68 m
.d
.comb
+= self
.out
.eq(max_val
>> self
.start
)
70 ret
= (max_val
>> self
.start
) ^
((max_val
>> self
.end
) >> 1)
71 m
.d
.comb
+= self
.out
.eq(Mux(self
.start
> self
.end
, ~ret
, ret
))
def rotl64(v, amt):
    """Rotate the 64-bit value *v* left by *amt* bits.

    Implemented by concatenating two copies of *v* and shifting the
    128-bit result right by ``64 - amt``: bit slice [0:64] of that is
    the left-rotation.  NOTE(review): the ``def`` header line was lost
    in the mangled source; name and signature reconstructed from the
    call ``rotl64(Cat(v[:32], v[:32]), amt)`` in the 32-bit helper.
    """
    v |= Const(0, 64)  # convert to value at least 64-bits wide
    amt |= Const(0, 6)  # convert to value at least 6-bits wide
    return (Cat(v[:64], v[:64]) >> (64 - amt[:6]))[:64]
def rotl32(v, amt):
    """Rotate the 32-bit value *v* left by *amt*, via `rotl64`.

    Doubles the 32-bit input into a 64-bit ``{v, v}`` pattern and
    rotates that, so the low 32 bits hold the 32-bit rotation.
    NOTE(review): the ``def`` header line was lost in the mangled
    source; the name is inferred from the 32-bit widening pattern —
    confirm against callers before relying on it.
    """
    v |= Const(0, 32)  # convert to value at least 32-bits wide
    return rotl64(Cat(v[:32], v[:32]), amt)
86 # This defines a module to drive the device under test and assert
87 # properties about its outputs
88 class Driver(Elaboratable
):
89 def __init__(self
, which
):
90 assert isinstance(which
, TstOp
)
93 def elaborate(self
, platform
):
97 pspec
= ShiftRotPipeSpec(id_wid
=2, parent_pspec
=None)
98 pspec
.draft_bitmanip
= True
99 m
.submodules
.dut
= dut
= ShiftRotMainStage(pspec
)
101 # Set inputs to formal variables
103 eq_any_const(dut
.i
.ctx
.op
.insn_type
),
104 eq_any_const(dut
.i
.ctx
.op
.fn_unit
),
105 eq_any_const(dut
.i
.ctx
.op
.imm_data
.data
),
106 eq_any_const(dut
.i
.ctx
.op
.imm_data
.ok
),
107 eq_any_const(dut
.i
.ctx
.op
.rc
.rc
),
108 eq_any_const(dut
.i
.ctx
.op
.rc
.ok
),
109 eq_any_const(dut
.i
.ctx
.op
.oe
.oe
),
110 eq_any_const(dut
.i
.ctx
.op
.oe
.ok
),
111 eq_any_const(dut
.i
.ctx
.op
.write_cr0
),
112 eq_any_const(dut
.i
.ctx
.op
.input_carry
),
113 eq_any_const(dut
.i
.ctx
.op
.output_carry
),
114 eq_any_const(dut
.i
.ctx
.op
.input_cr
),
115 eq_any_const(dut
.i
.ctx
.op
.is_32bit
),
116 eq_any_const(dut
.i
.ctx
.op
.is_signed
),
117 eq_any_const(dut
.i
.ctx
.op
.insn
),
118 eq_any_const(dut
.i
.xer_ca
),
119 eq_any_const(dut
.i
.ra
),
120 eq_any_const(dut
.i
.rb
),
121 eq_any_const(dut
.i
.rc
),
124 # check that the operation (op) is passed through (and muxid)
125 comb
+= Assert(dut
.o
.ctx
.op
== dut
.i
.ctx
.op
)
126 comb
+= Assert(dut
.o
.ctx
.muxid
== dut
.i
.ctx
.muxid
)
128 # we're only checking a particular operation:
129 comb
+= Assume(dut
.i
.ctx
.op
.insn_type
== self
.which
.op
)
131 # dispatch to check fn for each op
132 getattr(self
, f
"_check_{self.which.name.lower()}")(m
, dut
)
136 # all following code in elaborate is kept for ease of reference, to be
137 # deleted once this proof is completed.
139 # convenience variables
140 rs
= dut
.i
.rs
# register to shift
141 b
= dut
.i
.rb
# register containing amount to shift by
142 ra
= dut
.i
.a
# source register if masking is to be done
143 carry_in
= dut
.i
.xer_ca
[0]
144 carry_in32
= dut
.i
.xer_ca
[1]
145 carry_out
= dut
.o
.xer_ca
147 print("fields", rec
.fields
)
148 itype
= rec
.insn_type
151 m_fields
= dut
.fields
.FormM
152 md_fields
= dut
.fields
.FormMD
154 # setup random inputs
155 comb
+= rs
.eq(AnyConst(64))
156 comb
+= ra
.eq(AnyConst(64))
157 comb
+= b
.eq(AnyConst(64))
158 comb
+= carry_in
.eq(AnyConst(1))
159 comb
+= carry_in32
.eq(AnyConst(1))
162 comb
+= dut
.i
.ctx
.op
.eq(rec
)
164 # check that the operation (op) is passed through (and muxid)
165 comb
+= Assert(dut
.o
.ctx
.op
== dut
.i
.ctx
.op
)
166 comb
+= Assert(dut
.o
.ctx
.muxid
== dut
.i
.ctx
.muxid
)
168 # signed and signed/32 versions of input rs
169 a_signed
= Signal(signed(64))
170 a_signed_32
= Signal(signed(32))
171 comb
+= a_signed
.eq(rs
)
172 comb
+= a_signed_32
.eq(rs
[0:32])
175 mb
= Signal(7, reset_less
=True)
176 ml
= Signal(64, reset_less
=True)
179 with m
.If((itype
== MicrOp
.OP_RLC
) |
(itype
== MicrOp
.OP_RLCL
)):
180 with m
.If(rec
.is_32bit
):
181 comb
+= mb
.eq(m_fields
.MB
[:])
183 comb
+= mb
.eq(md_fields
.mb
[:])
185 with m
.If(rec
.is_32bit
):
186 comb
+= mb
.eq(b
[0:6])
189 comb
+= ml
.eq(left_mask(m
, mb
))
192 me
= Signal(7, reset_less
=True)
193 mr
= Signal(64, reset_less
=True)
196 with m
.If((itype
== MicrOp
.OP_RLC
) |
(itype
== MicrOp
.OP_RLCR
)):
197 with m
.If(rec
.is_32bit
):
198 comb
+= me
.eq(m_fields
.ME
[:])
200 comb
+= me
.eq(md_fields
.me
[:])
202 with m
.If(rec
.is_32bit
):
203 comb
+= me
.eq(b
[0:6])
206 comb
+= mr
.eq(right_mask(m
, me
))
212 # main assertion of arithmetic operations
213 with m
.Switch(itype
):
215 # left-shift: 64/32-bit
216 with m
.Case(MicrOp
.OP_SHL
):
217 comb
+= Assume(ra
== 0)
218 with m
.If(rec
.is_32bit
):
219 comb
+= Assert(o
[0:32] == ((rs
<< b
[0:6]) & 0xffffffff))
220 comb
+= Assert(o
[32:64] == 0)
222 comb
+= Assert(o
== ((rs
<< b
[0:7]) & ((1 << 64)-1)))
224 # right-shift: 64/32-bit / signed
225 with m
.Case(MicrOp
.OP_SHR
):
226 comb
+= Assume(ra
== 0)
227 with m
.If(~rec
.is_signed
):
228 with m
.If(rec
.is_32bit
):
229 comb
+= Assert(o
[0:32] == (rs
[0:32] >> b
[0:6]))
230 comb
+= Assert(o
[32:64] == 0)
232 comb
+= Assert(o
== (rs
>> b
[0:7]))
234 with m
.If(rec
.is_32bit
):
235 comb
+= Assert(o
[0:32] == (a_signed_32
>> b
[0:6]))
236 comb
+= Assert(o
[32:64] == Repl(rs
[31], 32))
238 comb
+= Assert(o
== (a_signed
>> b
[0:7]))
240 # extswsli: 32/64-bit moded
241 with m
.Case(MicrOp
.OP_EXTSWSLI
):
242 comb
+= Assume(ra
== 0)
243 with m
.If(rec
.is_32bit
):
244 comb
+= Assert(o
[0:32] == ((rs
<< b
[0:6]) & 0xffffffff))
245 comb
+= Assert(o
[32:64] == 0)
247 # sign-extend to 64 bit
248 a_s
= Signal(64, reset_less
=True)
249 comb
+= a_s
.eq(exts(rs
, 32, 64))
250 comb
+= Assert(o
== ((a_s
<< b
[0:7]) & ((1 << 64)-1)))
252 # rlwinm, rlwnm, rlwimi
253 # *CAN* these even be 64-bit capable? I don't think they are.
254 with m
.Case(MicrOp
.OP_RLC
):
255 comb
+= Assume(ra
== 0)
256 comb
+= Assume(rec
.is_32bit
)
258 # Duplicate some signals so that they're much easier to find
260 # Pro-tip: when debugging, factor out expressions into
262 # signals, and search using a unique grep-tag (RLC in my case).
264 # debugging, resubstitute values to comply with surrounding
267 mrl
= Signal(64, reset_less
=True, name
='MASK_FOR_RLC')
269 comb
+= mrl
.eq(ml | mr
)
271 comb
+= mrl
.eq(ml
& mr
)
273 ainp
= Signal(64, reset_less
=True, name
='A_INP_FOR_RLC')
274 comb
+= ainp
.eq(field(rs
, 32, 63))
276 sh
= Signal(6, reset_less
=True, name
='SH_FOR_RLC')
277 comb
+= sh
.eq(b
[0:6])
279 exp_shl
= Signal(64, reset_less
=True,
280 name
='A_SHIFTED_LEFT_BY_SH_FOR_RLC')
281 comb
+= exp_shl
.eq((ainp
<< sh
) & 0xFFFFFFFF)
283 exp_shr
= Signal(64, reset_less
=True,
284 name
='A_SHIFTED_RIGHT_FOR_RLC')
285 comb
+= exp_shr
.eq((ainp
>> (32 - sh
)) & 0xFFFFFFFF)
287 exp_rot
= Signal(64, reset_less
=True,
288 name
='A_ROTATED_LEFT_FOR_RLC')
289 comb
+= exp_rot
.eq(exp_shl | exp_shr
)
291 exp_ol
= Signal(32, reset_less
=True,
292 name
='EXPECTED_OL_FOR_RLC')
293 comb
+= exp_ol
.eq(field((exp_rot
& mrl
) |
(ainp
& ~mrl
),
296 act_ol
= Signal(32, reset_less
=True, name
='ACTUAL_OL_FOR_RLC')
297 comb
+= act_ol
.eq(field(o
, 32, 63))
299 # If I uncomment the following lines, I can confirm that all
300 # 32-bit rotations work. If I uncomment only one of the
301 # following lines, I can confirm that all 32-bit rotations
302 # work. When I remove/recomment BOTH lines, however, the
303 # assertion fails. Why??
305 # comb += Assume(mr == 0xFFFFFFFF)
306 # comb += Assume(ml == 0xFFFFFFFF)
307 # with m.If(rec.is_32bit):
308 # comb += Assert(act_ol == exp_ol)
309 # comb += Assert(field(o, 0, 31) == 0)
312 with m
.Case(MicrOp
.OP_RLCR
):
314 with m
.Case(MicrOp
.OP_RLCL
):
316 with m
.Case(MicrOp
.OP_TERNLOG
):
317 lut
= dut
.fields
.FormTLI
.TLI
[:]
319 idx
= Cat(dut
.i
.rb
[i
], dut
.i
.ra
[i
], dut
.i
.rc
[i
])
322 comb
+= Assert(dut
.o
.o
.data
[i
] == lut
[j
])
323 with m
.Case(MicrOp
.OP_GREV
):
324 ra_bits
= Array(dut
.i
.ra
[i
] for i
in range(64))
325 with m
.If(dut
.i
.ctx
.op
.is_32bit
):
326 # assert zero-extended
327 comb
+= Assert(dut
.o
.o
.data
[32:] == 0)
329 idx
= dut
.i
.rb
[0:5] ^ i
330 comb
+= Assert(dut
.o
.o
.data
[i
]
334 idx
= dut
.i
.rb
[0:6] ^ i
335 comb
+= Assert(dut
.o
.o
.data
[i
]
341 # check that data ok was only enabled when op actioned
342 comb
+= Assert(dut
.o
.o
.ok
== o_ok
)
346 def _check_shl(self
, m
, dut
):
347 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
348 expected
= Signal(64)
349 with m
.If(dut
.i
.ctx
.op
.is_32bit
):
350 m
.d
.comb
+= expected
.eq((dut
.i
.rs
<< dut
.i
.rb
[:6])[:32])
352 m
.d
.comb
+= expected
.eq((dut
.i
.rs
<< dut
.i
.rb
[:7])[:64])
353 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
354 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
356 def _check_shr(self
, m
, dut
):
357 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
358 expected
= Signal(64)
360 shift_in_s
= Signal(signed(128))
361 shift_roundtrip
= Signal(signed(128))
362 shift_in_u
= Signal(128)
363 shift_amt
= Signal(7)
364 with m
.If(dut
.i
.ctx
.op
.is_32bit
):
366 shift_amt
.eq(dut
.i
.rb
[:6]),
367 shift_in_s
.eq(dut
.i
.rs
[:32].as_signed()),
368 shift_in_u
.eq(dut
.i
.rs
[:32]),
372 shift_amt
.eq(dut
.i
.rb
[:7]),
373 shift_in_s
.eq(dut
.i
.rs
.as_signed()),
374 shift_in_u
.eq(dut
.i
.rs
),
377 with m
.If(dut
.i
.ctx
.op
.is_signed
):
379 expected
.eq(shift_in_s
>> shift_amt
),
380 shift_roundtrip
.eq((shift_in_s
>> shift_amt
) << shift_amt
),
381 carry
.eq((shift_in_s
< 0) & (shift_roundtrip
!= shift_in_s
)),
385 expected
.eq(shift_in_u
>> shift_amt
),
388 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
389 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== Repl(carry
, 2))
391 def _check_rlc(self
, m
, dut
):
392 raise NotImplementedError
393 m
.submodules
.mask
= mask
= Mask()
396 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
    def _check_rlcl(self, m, dut):
        # proof for OP_RLCL not yet written; running TstOp.RLCL will
        # raise until this is implemented
        raise NotImplementedError
    def _check_rlcr(self, m, dut):
        # proof for OP_RLCR not yet written; running TstOp.RLCR will
        # raise until this is implemented
        raise NotImplementedError
404 def _check_extswsli(self
, m
, dut
):
405 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
406 m
.d
.comb
+= Assume(dut
.i
.rb
[6:] == 0)
407 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
) # all instrs. are 64-bit
408 expected
= Signal(64)
409 m
.d
.comb
+= expected
.eq((dut
.i
.rs
[0:32].as_signed() << dut
.i
.rb
[:6]))
410 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
411 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
413 def _check_ternlog(self
, m
, dut
):
414 lut
= dut
.fields
.FormTLI
.TLI
[:]
416 idx
= Cat(dut
.i
.rb
[i
], dut
.i
.ra
[i
], dut
.i
.rc
[i
])
419 m
.d
.comb
+= Assert(dut
.o
.o
.data
[i
] == lut
[j
])
420 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
422 def _check_grev32(self
, m
, dut
):
423 m
.d
.comb
+= Assume(dut
.i
.ctx
.op
.is_32bit
)
424 # assert zero-extended
425 m
.d
.comb
+= Assert(dut
.o
.o
.data
[32:] == 0)
427 m
.d
.comb
+= eq_any_const(i
)
428 idx
= dut
.i
.rb
[0: 5] ^ i
429 m
.d
.comb
+= Assert((dut
.o
.o
.data
>> i
)[0] == (dut
.i
.ra
>> idx
)[0])
430 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
432 def _check_grev64(self
, m
, dut
):
433 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
)
435 m
.d
.comb
+= eq_any_const(i
)
436 idx
= dut
.i
.rb
[0: 6] ^ i
437 m
.d
.comb
+= Assert((dut
.o
.o
.data
>> i
)[0] == (dut
.i
.ra
>> idx
)[0])
438 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
441 class ALUTestCase(FHDLTestCase
):
442 def run_it(self
, which
):
443 module
= Driver(which
)
444 self
.assertFormal(module
, mode
="bmc", depth
=2)
445 self
.assertFormal(module
, mode
="cover", depth
=2)
448 self
.run_it(TstOp
.SHL
)
451 self
.run_it(TstOp
.SHR
)
454 self
.run_it(TstOp
.RLC
)
457 self
.run_it(TstOp
.RLCL
)
460 self
.run_it(TstOp
.RLCR
)
    def test_extswsli(self):
        # formal proof for the extswsli instruction only
        self.run_it(TstOp.EXTSWSLI)
    def test_ternlog(self):
        # formal proof for the ternlog instruction only
        self.run_it(TstOp.TERNLOG)
    def test_grev32(self):
        # formal proof for the 32-bit grev instruction only
        self.run_it(TstOp.GREV32)
    def test_grev64(self):
        # formal proof for the 64-bit grev instruction only
        self.run_it(TstOp.GREV64)
475 # check that all test cases are covered
477 assert callable(getattr(ALUTestCase
, f
"test_{i.name.lower()}"))
480 if __name__
== '__main__':