From 442546d5bffafb8571f50b1a4dc48e83ace1ce47 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 22 Feb 2022 01:33:16 -0800 Subject: [PATCH] speed up shift/rot formal proof by running stuff in parallel it now runs in about 1m30s in `pytest -n auto .py` also start work on rewriting proofs to hopefully work better, all of grev, ternlog, extswsli, shl, and shr are completed. --- .../fu/shift_rot/formal/proof_main_stage.py | 264 +++++++++++++++--- 1 file changed, 231 insertions(+), 33 deletions(-) diff --git a/src/soc/fu/shift_rot/formal/proof_main_stage.py b/src/soc/fu/shift_rot/formal/proof_main_stage.py index c4dd461d..be0c4b16 100644 --- a/src/soc/fu/shift_rot/formal/proof_main_stage.py +++ b/src/soc/fu/shift_rot/formal/proof_main_stage.py @@ -1,12 +1,14 @@ -# Proof of correctness for partitioned equal signal combiner +# Proof of correctness for shift/rotate FU # Copyright (C) 2020 Michael Nolan """ Links: * https://bugs.libre-soc.org/show_bug.cgi?id=340 """ +import enum +from shutil import which from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl, - signed, Array) + signed, Array, Const, Value) from nmigen.asserts import Assert, AnyConst, Assume, Cover from nmutil.formaltest import FHDLTestCase from nmigen.cli import rtlil @@ -22,42 +24,118 @@ import unittest from nmutil.extend import exts +@enum.unique +class TstOp(enum.Enum): + """ops we're testing, the idea is if we run a separate formal proof for + each instruction, we end up covering them all and each runs much faster, + also the formal proofs can be run in parallel.""" + SHL = MicrOp.OP_SHL + SHR = MicrOp.OP_SHR + RLC = MicrOp.OP_RLC + RLCL = MicrOp.OP_RLCL + RLCR = MicrOp.OP_RLCR + EXTSWSLI = MicrOp.OP_EXTSWSLI + TERNLOG = MicrOp.OP_TERNLOG + GREV32 = MicrOp.OP_GREV, 32 + GREV64 = MicrOp.OP_GREV, 64 + + @property + def op(self): + if isinstance(self.value, tuple): + return self.value[0] + return self.value + + +def eq_any_const(sig: Signal): + return sig.eq(AnyConst(sig.shape(), src_loc_at=1)) + + +class Mask(Elaboratable): + # copied from qemu's mask fn: + # https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21 + def __init__(self): + self.start = Signal(6) + self.end = Signal(6) + self.out = Signal(64) + + def elaborate(self, platform): + m = Module() + max_val = Const(~0, 64) + max_bit = 63 + with m.If(self.start == 0): + m.d.comb += self.out.eq(max_val << (max_bit - self.end)) + with m.Elif(self.end == max_bit): + m.d.comb += self.out.eq(max_val >> self.start) + with m.Else(): + ret = (max_val >> self.start) ^ ((max_val >> self.end) >> 1) + m.d.comb += self.out.eq(Mux(self.start > self.end, ~ret, ret)) + return m + + +def rotl64(v, amt): + v |= Const(0, 64) # convert to value at least 64-bits wide + amt |= Const(0, 6) # convert to value at least 6-bits wide + return (Cat(v[:64], v[:64]) >> (64 - amt[:6]))[:64] + + +def rotl32(v, amt): + v |= Const(0, 32) # convert to value at least 32-bits wide + return rotl64(Cat(v[:32], v[:32]), amt) + + # This defines a module to drive the device under test and assert # properties about its outputs class Driver(Elaboratable): - def __init__(self): - # inputs and outputs - pass + def __init__(self, which): + assert isinstance(which, TstOp) + self.which = which def elaborate(self, platform): m = Module() comb = m.d.comb - rec = CompSROpSubset() - # Setup random inputs for dut.op. do them explicitly so that - # we can see which ones cause failures in the debug report - # for p in rec.ports(): - # comb += p.eq(AnyConst(p.width)) - comb += rec.insn_type.eq(AnyConst(rec.insn_type.width)) - comb += rec.fn_unit.eq(AnyConst(rec.fn_unit.width)) - comb += rec.imm_data.data.eq(AnyConst(rec.imm_data.data.width)) - comb += rec.imm_data.ok.eq(AnyConst(rec.imm_data.ok.width)) - comb += rec.rc.rc.eq(AnyConst(rec.rc.rc.width)) - comb += rec.rc.ok.eq(AnyConst(rec.rc.ok.width)) - comb += rec.oe.oe.eq(AnyConst(rec.oe.oe.width)) - comb += rec.oe.ok.eq(AnyConst(rec.oe.ok.width)) - comb += rec.write_cr0.eq(AnyConst(rec.write_cr0.width)) - comb += rec.input_carry.eq(AnyConst(rec.input_carry.width)) - comb += rec.output_carry.eq(AnyConst(rec.output_carry.width)) - comb += rec.input_cr.eq(AnyConst(rec.input_cr.width)) - comb += rec.is_32bit.eq(AnyConst(rec.is_32bit.width)) - comb += rec.is_signed.eq(AnyConst(rec.is_signed.width)) - comb += rec.insn.eq(AnyConst(rec.insn.width)) - pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None) pspec.draft_bitmanip = True m.submodules.dut = dut = ShiftRotMainStage(pspec) + # Set inputs to formal variables + comb += [ + eq_any_const(dut.i.ctx.op.insn_type), + eq_any_const(dut.i.ctx.op.fn_unit), + eq_any_const(dut.i.ctx.op.imm_data.data), + eq_any_const(dut.i.ctx.op.imm_data.ok), + eq_any_const(dut.i.ctx.op.rc.rc), + eq_any_const(dut.i.ctx.op.rc.ok), + eq_any_const(dut.i.ctx.op.oe.oe), + eq_any_const(dut.i.ctx.op.oe.ok), + eq_any_const(dut.i.ctx.op.write_cr0), + eq_any_const(dut.i.ctx.op.input_carry), + eq_any_const(dut.i.ctx.op.output_carry), + eq_any_const(dut.i.ctx.op.input_cr), + eq_any_const(dut.i.ctx.op.is_32bit), + eq_any_const(dut.i.ctx.op.is_signed), + eq_any_const(dut.i.ctx.op.insn), + eq_any_const(dut.i.xer_ca), + eq_any_const(dut.i.ra), + eq_any_const(dut.i.rb), + eq_any_const(dut.i.rc), + ] + + # check that the operation (op) is passed through (and muxid) + comb += Assert(dut.o.ctx.op == dut.i.ctx.op) + comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid) + + # we're only checking a particular operation: + comb += Assume(dut.i.ctx.op.insn_type == self.which.op) + + # dispatch to check fn for each op + getattr(self, f"_check_{self.which.name.lower()}")(m, dut) + + return m + + # all following code in elaborate is kept for ease of reference, to be + # deleted once this proof is completed. + # convenience variables rs = dut.i.rs # register to shift b = dut.i.rb # register containing amount to shift by @@ -265,18 +343,138 @@ class Driver(Elaboratable): return m + def _check_shl(self, m, dut): + m.d.comb += Assume(dut.i.ra == 0) + expected = Signal(64) + with m.If(dut.i.ctx.op.is_32bit): + m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:6])[:32]) + with m.Else(): + m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:7])[:64]) + m.d.comb += Assert(dut.o.o.data == expected) + m.d.comb += Assert(dut.o.xer_ca.data == 0) + + def _check_shr(self, m, dut): + m.d.comb += Assume(dut.i.ra == 0) + expected = Signal(64) + carry = Signal() + shift_in_s = Signal(signed(128)) + shift_roundtrip = Signal(signed(128)) + shift_in_u = Signal(128) + shift_amt = Signal(7) + with m.If(dut.i.ctx.op.is_32bit): + m.d.comb += [ + shift_amt.eq(dut.i.rb[:6]), + shift_in_s.eq(dut.i.rs[:32].as_signed()), + shift_in_u.eq(dut.i.rs[:32]), + ] + with m.Else(): + m.d.comb += [ + shift_amt.eq(dut.i.rb[:7]), + shift_in_s.eq(dut.i.rs.as_signed()), + shift_in_u.eq(dut.i.rs), + ] + + with m.If(dut.i.ctx.op.is_signed): + m.d.comb += [ + expected.eq(shift_in_s >> shift_amt), + shift_roundtrip.eq((shift_in_s >> shift_amt) << shift_amt), + carry.eq((shift_in_s < 0) & (shift_roundtrip != shift_in_s)), + ] + with m.Else(): + m.d.comb += [ + expected.eq(shift_in_u >> shift_amt), + carry.eq(0), + ] + m.d.comb += Assert(dut.o.o.data == expected) + m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2)) + + def _check_rlc(self, m, dut): + raise NotImplementedError + m.submodules.mask = mask = Mask() + with m.If(): + pass + m.d.comb += Assert(dut.o.xer_ca.data == 0) + + def _check_rlcl(self, m, dut): + raise NotImplementedError + + def _check_rlcr(self, m, dut): + raise NotImplementedError + + def _check_extswsli(self, m, dut): + m.d.comb += Assume(dut.i.ra == 0) + m.d.comb += Assume(dut.i.rb[6:] == 0) + m.d.comb += Assume(~dut.i.ctx.op.is_32bit) # all instrs. are 64-bit + expected = Signal(64) + m.d.comb += expected.eq((dut.i.rs[0:32].as_signed() << dut.i.rb[:6])) + m.d.comb += Assert(dut.o.o.data == expected) + m.d.comb += Assert(dut.o.xer_ca.data == 0) + + def _check_ternlog(self, m, dut): + lut = dut.fields.FormTLI.TLI[:] + for i in range(64): + idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i]) + for j in range(8): + with m.If(j == idx): + m.d.comb += Assert(dut.o.o.data[i] == lut[j]) + m.d.comb += Assert(dut.o.xer_ca.data == 0) + + def _check_grev32(self, m, dut): + m.d.comb += Assume(dut.i.ctx.op.is_32bit) + # assert zero-extended + m.d.comb += Assert(dut.o.o.data[32:] == 0) + i = Signal(5) + m.d.comb += eq_any_const(i) + idx = dut.i.rb[0: 5] ^ i + m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0]) + m.d.comb += Assert(dut.o.xer_ca.data == 0) + + def _check_grev64(self, m, dut): + m.d.comb += Assume(~dut.i.ctx.op.is_32bit) + i = Signal(6) + m.d.comb += eq_any_const(i) + idx = dut.i.rb[0: 6] ^ i + m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0]) + m.d.comb += Assert(dut.o.xer_ca.data == 0) + class ALUTestCase(FHDLTestCase): - def test_formal(self): - module = Driver() + def run_it(self, which): + module = Driver(which) self.assertFormal(module, mode="bmc", depth=2) self.assertFormal(module, mode="cover", depth=2) - def test_ilang(self): - dut = Driver() - vl = rtlil.convert(dut, ports=[]) - with open("main_stage.il", "w") as f: - f.write(vl) + def test_shl(self): + self.run_it(TstOp.SHL) + + def test_shr(self): + self.run_it(TstOp.SHR) + + def test_rlc(self): + self.run_it(TstOp.RLC) + + def test_rlcl(self): + self.run_it(TstOp.RLCL) + + def test_rlcr(self): + self.run_it(TstOp.RLCR) + + def test_extswsli(self): + self.run_it(TstOp.EXTSWSLI) + + def test_ternlog(self): + self.run_it(TstOp.TERNLOG) + + def test_grev32(self): + self.run_it(TstOp.GREV32) + + def test_grev64(self): + self.run_it(TstOp.GREV64) + + +# check that all test cases are covered +for i in TstOp: + assert callable(getattr(ALUTestCase, f"test_{i.name.lower()}")) if __name__ == '__main__': -- 2.30.2