# Proof of correctness for shift/rotate FU
# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
"""
* https://bugs.libre-soc.org/show_bug.cgi?id=340
"""

import enum
import unittest

from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
                    signed, Const, unsigned)
from nmigen.asserts import Assert, AnyConst, Assume
from nmutil.formaltest import FHDLTestCase
from nmutil.sim_util import do_sim
from nmigen.sim import Delay

from soc.fu.shift_rot.main_stage import ShiftRotMainStage
from soc.fu.shift_rot.pipe_data import ShiftRotPipeSpec
from openpower.decoder.power_enums import MicrOp


class TstOp(enum.Enum):
    """ops we're testing, the idea is if we run a separate formal proof for
    each instruction, we end up covering them all and each runs much faster,
    also the formal proofs can be run in parallel."""
    SHL = MicrOp.OP_SHL
    SHR = MicrOp.OP_SHR
    RLC32 = MicrOp.OP_RLC, 32
    RLC64 = MicrOp.OP_RLC, 64
    RLCL = MicrOp.OP_RLCL
    RLCR = MicrOp.OP_RLCR
    EXTSWSLI = MicrOp.OP_EXTSWSLI
    TERNLOG = MicrOp.OP_TERNLOG
    GREV32 = MicrOp.OP_GREV, 32
    GREV64 = MicrOp.OP_GREV, 64

    @property
    def op(self):
        if isinstance(self.value, tuple):
            return self.value[0]
        return self.value
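
# note: TstOp members whose value is a (MicrOp, width) tuple let the same
# MicrOp (OP_RLC, OP_GREV) be proved separately in its 32-bit and 64-bit
# forms; the `op` property strips the width tag so the Driver can match
# purely on insn_type.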


def eq_any_const(sig: Signal):
    return sig.eq(AnyConst(sig.shape(), src_loc_at=1))
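
# eq_any_const ties a signal to an AnyConst: a value the formal solver may
# choose freely but which stays constant for the whole trace.  This is how
# every DUT input is left unconstrained in the Driver below.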


class Mask(Elaboratable):
    # copied from qemu's mask fn:
    # https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21
    def __init__(self):
        self.start = Signal(6)
        self.end = Signal(6)
        self.out = Signal(64)

    def elaborate(self, platform):
        m = Module()
        max_val = Const(~0, unsigned(64))
        max_bit = 63
        with m.If(self.start == 0):
            m.d.comb += self.out.eq(max_val << (max_bit - self.end))
        with m.Elif(self.end == max_bit):
            m.d.comb += self.out.eq(max_val >> self.start)
        with m.Else():
            ret = (max_val >> self.start) ^ ((max_val >> self.end) >> 1)
            m.d.comb += self.out.eq(Mux(self.start > self.end, ~ret, ret))
        return m


class TstMask(unittest.TestCase):
    def test_mask(self):
        dut = Mask()

        def case(start, end, expected):
            with self.subTest(start=start, end=end):
                yield dut.start.eq(start)
                yield dut.end.eq(end)
                yield Delay(1e-6)
                out = yield dut.out
                with self.subTest(out=hex(out), expected=hex(expected)):
                    self.assertEqual(expected, out)

        def process():
            for start in range(64):
                for end in range(64):
                    expected = 0
                    if start > end:
                        # wrap-around case: bits start..63 and 0..end
                        for i in range(start, 64):
                            expected |= 1 << (63 - i)
                        for i in range(0, end + 1):
                            expected |= 1 << (63 - i)
                    else:
                        for i in range(start, end + 1):
                            expected |= 1 << (63 - i)
                    yield from case(start, end, expected)
        with do_sim(self, dut, [dut.start, dut.end, dut.out]) as sim:
            sim.add_process(process)
            sim.run()


def rotl64(v, amt):
    v |= Const(0, 64)  # convert to value at least 64-bits wide
    amt |= Const(0, 6)  # convert to value at least 6-bits wide
    return (Cat(v[:64], v[:64]) >> (64 - amt[:6]))[:64]


def rotl32(v, amt):
    v |= Const(0, 32)  # convert to value at least 32-bits wide
    return rotl64(Cat(v[:32], v[:32]), amt)
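
# rotl64 rotates left by duplicating the value and shifting the 128-bit
# concatenation right by (64 - amt).  rotl32 feeds the 32-bit value,
# duplicated into both halves, through the same 64-bit rotate, so both the
# low and the high 32 bits of the result hold the rotated 32-bit value.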


# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    def __init__(self, which):
        assert isinstance(which, TstOp)
        self.which = which

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None)
        pspec.draft_bitmanip = True
        m.submodules.dut = dut = ShiftRotMainStage(pspec)
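        # note: pspec.draft_bitmanip above enables the experimental
        # grev/ternlog support in the shift/rotate stage, which the
        # GREV*/TERNLOG proofs rely on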

        # Set inputs to formal variables
        comb += [
            eq_any_const(dut.i.ctx.op.insn_type),
            eq_any_const(dut.i.ctx.op.fn_unit),
            eq_any_const(dut.i.ctx.op.imm_data.data),
            eq_any_const(dut.i.ctx.op.imm_data.ok),
            eq_any_const(dut.i.ctx.op.rc.rc),
            eq_any_const(dut.i.ctx.op.rc.ok),
            eq_any_const(dut.i.ctx.op.oe.oe),
            eq_any_const(dut.i.ctx.op.oe.ok),
            eq_any_const(dut.i.ctx.op.write_cr0),
            eq_any_const(dut.i.ctx.op.input_carry),
            eq_any_const(dut.i.ctx.op.output_carry),
            eq_any_const(dut.i.ctx.op.input_cr),
            eq_any_const(dut.i.ctx.op.is_32bit),
            eq_any_const(dut.i.ctx.op.is_signed),
            eq_any_const(dut.i.ctx.op.insn),
            eq_any_const(dut.i.xer_ca),
            eq_any_const(dut.i.ra),
            eq_any_const(dut.i.rb),
            eq_any_const(dut.i.rc),
        ]

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # we're only checking a particular operation:
        comb += Assume(dut.i.ctx.op.insn_type == self.which.op)

        # dispatch to check fn for each op
        getattr(self, f"_check_{self.which.name.lower()}")(m, dut)

        return m

    def _check_shl(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:6])[:32])
        with m.Else():
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:7])[:64])
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_shr(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        carry = Signal()
        shift_in_s = Signal(signed(128))
        shift_roundtrip = Signal(signed(128))
        shift_in_u = Signal(128)
        shift_amt = Signal(7)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:6]),
                shift_in_s.eq(dut.i.rs[:32].as_signed()),
                shift_in_u.eq(dut.i.rs[:32]),
            ]
        with m.Else():
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:7]),
                shift_in_s.eq(dut.i.rs.as_signed()),
                shift_in_u.eq(dut.i.rs),
            ]

        with m.If(dut.i.ctx.op.is_signed):
            m.d.comb += [
                expected.eq(shift_in_s >> shift_amt),
                shift_roundtrip.eq((shift_in_s >> shift_amt) << shift_amt),
                carry.eq((shift_in_s < 0) & (shift_roundtrip != shift_in_s)),
            ]
        with m.Else():
            m.d.comb += [
                expected.eq(shift_in_u >> shift_amt),
                carry.eq(0),
            ]
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2))
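
    # note (sraw[i]/srad[i]): the CA check above follows the Power ISA rule
    # that carry is set only when the source is negative and the shift
    # discards one or more 1 bits, which is detected by shifting back
    # (shift_roundtrip) and comparing with the original value.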

    def _check_rlc32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # rlwimi, rlwinm, and rlwnm

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl32(dut.i.rs[:32], dut.i.rb[:5]))
        m.d.comb += mask.start.eq(dut.fields.FormM.MB[:] + 32)
        m.d.comb += mask.end.eq(dut.fields.FormM.ME[:] + 32)

        # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlc64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldic and rldimi

        # `rb` is always a 6-bit immediate
        m.d.comb += Assume(dut.i.rb[6:] == 0)

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))
        m.d.comb += mask.end.eq(63 - dut.i.rb[:6])

        # for rldic, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcl(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicl and rldcl

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.end.eq(63)
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcr(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicr and rldcr

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.start.eq(0)
        me = dut.fields.FormMD.me[:]
        m.d.comb += mask.end.eq(Cat(me[1:6], me[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_extswsli(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        m.d.comb += Assume(dut.i.rb[6:] == 0)
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)  # all instrs. are 64-bit
        expected = Signal(64)
        m.d.comb += expected.eq(dut.i.rs[0:32].as_signed() << dut.i.rb[:6])
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_ternlog(self, m, dut):
        lut = dut.fields.FormTLI.TLI[:]

        for i in range(64):
            idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
            for j in range(8):
                with m.If(idx == j):
                    m.d.comb += Assert(dut.o.o.data[i] == lut[j])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)
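
    # note: ternlog is a bit-sliced 3-input LUT.  For every bit position i
    # the three input bits (RB, RA, RC, least to most significant) form an
    # index into the 8-bit TLI immediate, and the selected LUT bit must
    # appear at bit i of the output.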

    def _check_grev32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # assert zero-extended
        m.d.comb += Assert(dut.o.o.data[32:] == 0)
        i = Signal(5)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:5] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        i = Signal(6)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:6] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)
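
    # note: the grev checks use a symbolic bit index `i` (an AnyConst), so a
    # single proof covers every bit position: output bit i must equal input
    # bit (i XOR the low bits of RB).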


class ALUTestCase(FHDLTestCase):
    def run_it(self, which):
        module = Driver(which)
        self.assertFormal(module, mode="bmc", depth=2)
        self.assertFormal(module, mode="cover", depth=2)

    def test_shl(self):
        self.run_it(TstOp.SHL)

    def test_shr(self):
        self.run_it(TstOp.SHR)

    def test_rlc32(self):
        self.run_it(TstOp.RLC32)

    def test_rlc64(self):
        self.run_it(TstOp.RLC64)

    def test_rlcl(self):
        self.run_it(TstOp.RLCL)

    def test_rlcr(self):
        self.run_it(TstOp.RLCR)

    def test_extswsli(self):
        self.run_it(TstOp.EXTSWSLI)

    def test_ternlog(self):
        self.run_it(TstOp.TERNLOG)

    def test_grev32(self):
        self.run_it(TstOp.GREV32)

    def test_grev64(self):
        self.run_it(TstOp.GREV64)


# check that all test cases are covered
for i in TstOp:
    assert callable(getattr(ALUTestCase, f"test_{i.name.lower()}"))


if __name__ == '__main__':
    unittest.main()