From: Jacob Lifshay Date: Fri, 17 May 2024 08:16:04 +0000 (-0700) Subject: hdl/test/test_gfbinv: add formal proof for gfbinv X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=35e5f677384814056e63e17e4c5005366017cebf;p=nmigen-gf.git hdl/test/test_gfbinv: add formal proof for gfbinv --- diff --git a/src/nmigen_gf/hdl/test/test_gfbinv.py b/src/nmigen_gf/hdl/test/test_gfbinv.py index c639062..c851313 100644 --- a/src/nmigen_gf/hdl/test/test_gfbinv.py +++ b/src/nmigen_gf/hdl/test/test_gfbinv.py @@ -10,7 +10,9 @@ https://bugs.libre-soc.org/show_bug.cgi?id=785 """ import unittest -from nmigen.hdl.ast import Const, unsigned +from nmigen.hdl.ast import \ + AnyConst, Assert, Assume, Const, Past, Initial, ResetSignal, unsigned, Signal +from nmigen.hdl.dsl import Module from nmutil.formaltest import FHDLTestCase from nmigen_gf.reference.gfbinv import gfbinv from nmigen_gf.reference.is_irreducible import is_irreducible @@ -18,6 +20,8 @@ from nmigen_gf.hdl.gfbinv import \ py_gfbinv_algorithm, GFBInvShape, GFBInvFSMStage from nmigen_gf.reference.decode_reducing_polynomial import \ decode_reducing_polynomial +from nmigen_gf.hdl.decode_reducing_polynomial import DecodeReducingPolynomial +from nmigen_gf.hdl.clmul import clmul, clmuladd from nmigen.sim import Delay, Tick from nmutil.sim_util import do_sim, hash_256 from nmigen_gf.reference.state import ST @@ -159,6 +163,106 @@ class TestGFBInv(FHDLTestCase): def test_32(self): self.tst(XLEN=32, full=False) + def tst_formal(self, XLEN): + # type: (int) -> None + m = Module() + shape = GFBInvShape(width=XLEN) + pspec = {} + dut = GFBInvFSMStage(pspec, shape) + m.submodules.dut = dut + i_data = dut.p.i_data + o_data = dut.n.o_data + + # for some reason formal likes to keep reset asserted for the entire + # trace, force it to just be asserted at the beginning + m.d.comb += Assume(ResetSignal() == Initial()) + + def set_any_const(v, *, src_loc_at=0): + m.d.comb += v.eq(AnyConst(v.shape(), src_loc_at=src_loc_at + 1)) + + set_any_const(i_data.REDPOLY) + set_any_const(i_data.a) + + drp = DecodeReducingPolynomial(XLEN) + m.submodules.drp = drp + m.d.comb += drp.REDPOLY.eq(i_data.REDPOLY) + + def get_degree(v, *, src_loc_at=0): + width = v.shape().width + deg = Signal(range(-1, width), src_loc_at=src_loc_at + 1) + set_any_const(deg) + for i in range(width): + with m.If(v >> i == 1): + m.d.comb += deg.eq(i) + with m.If(v == 0): + m.d.comb += deg.eq(-1) + return deg + + rpoly_degree = get_degree(drp.reducing_polynomial) + a_degree = get_degree(i_data.a) + output_degree = get_degree(o_data.output) + + m.d.comb += Assume(a_degree < rpoly_degree) + + m.d.comb += dut.p.i_valid.eq(Past(Initial())) + m.d.comb += dut.n.i_ready.eq(0) + with m.If(Past(Initial(), clocks=XLEN * 2 + 2)): + m.d.comb += Assert(dut.n.o_valid) + + rpoly_valid = Signal() + m.d.comb += rpoly_valid.eq(0) + + poly_map = {} + + for poly in range(2 ** (XLEN + 1)): + # too complex to check in hardware, so just list valid values + if is_irreducible(poly): + REDPOLY = poly + if REDPOLY >> XLEN != 0: + REDPOLY %= 2 ** XLEN + REDPOLY &= ~1 + elif poly == 2: + REDPOLY = 0 + with m.If(drp.REDPOLY == REDPOLY): + m.d.comb += rpoly_valid.eq(1) + m.d.comb += Assert(drp.reducing_polynomial == poly) + with self.subTest(poly=hex(poly), + REDPOLY=hex(REDPOLY)): + self.assertNotIn(REDPOLY, poly_map) + poly_map[REDPOLY] = poly + + m.d.comb += Assume(rpoly_valid) + + unreduced_inverse = Signal(XLEN * 2) + m.d.comb += unreduced_inverse.eq(clmul(i_data.a, o_data.output)) + + quot = Signal.like(drp.reducing_polynomial) + set_any_const(quot) + rem = Signal.like(drp.reducing_polynomial) + set_any_const(rem) + m.d.comb += Assume(unreduced_inverse == clmuladd( + drp.reducing_polynomial, quot, rem)) + rem_degree = get_degree(rem) + m.d.comb += Assume(rem_degree < rpoly_degree) + + inverse = Signal(XLEN * 2) + m.d.comb += inverse.eq(clmul(i_data.a, o_data.output)) + + output_valid = Signal() + output_valid_v = i_data.a == 0 + output_valid_v &= o_data.output == 0 + output_valid_v |= rem == 1 + output_valid_v &= output_degree < rpoly_degree + m.d.comb += output_valid.eq(output_valid_v) + m.d.comb += Assert(output_valid) + self.assertFormal(m, depth=XLEN * 2 + 3) + + def test_formal_4(self): + self.tst_formal(4) + + def test_formal_8(self): + self.tst_formal(8) + if __name__ == "__main__": unittest.main()