hdl/test/test_gfbinv: add formal proof for gfbinv
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 17 May 2024 08:16:04 +0000 (01:16 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 17 May 2024 08:16:04 +0000 (01:16 -0700)
src/nmigen_gf/hdl/test/test_gfbinv.py

index c639062ecf046fd1bcafd9d469d437b3b06fe615..c85131344d9a208f42f73ff58a4512988f044cb9 100644 (file)
@@ -10,7 +10,9 @@ https://bugs.libre-soc.org/show_bug.cgi?id=785
 """
 
 import unittest
-from nmigen.hdl.ast import Const, unsigned
+from nmigen.hdl.ast import \
+    AnyConst, Assert, Assume, Const, Past, Initial, ResetSignal, unsigned, Signal
+from nmigen.hdl.dsl import Module
 from nmutil.formaltest import FHDLTestCase
 from nmigen_gf.reference.gfbinv import gfbinv
 from nmigen_gf.reference.is_irreducible import is_irreducible
@@ -18,6 +20,8 @@ from nmigen_gf.hdl.gfbinv import \
     py_gfbinv_algorithm, GFBInvShape, GFBInvFSMStage
 from nmigen_gf.reference.decode_reducing_polynomial import \
     decode_reducing_polynomial
+from nmigen_gf.hdl.decode_reducing_polynomial import DecodeReducingPolynomial
+from nmigen_gf.hdl.clmul import clmul, clmuladd
 from nmigen.sim import Delay, Tick
 from nmutil.sim_util import do_sim, hash_256
 from nmigen_gf.reference.state import ST
@@ -159,6 +163,106 @@ class TestGFBInv(FHDLTestCase):
     def test_32(self):
         self.tst(XLEN=32, full=False)
 
+    def tst_formal(self, XLEN):
+        # type: (int) -> None
+        m = Module()
+        shape = GFBInvShape(width=XLEN)
+        pspec = {}
+        dut = GFBInvFSMStage(pspec, shape)
+        m.submodules.dut = dut
+        i_data = dut.p.i_data
+        o_data = dut.n.o_data
+
+        # for some reason formal likes to keep reset asserted for the entire
+        # trace, force it to just be asserted at the beginning
+        m.d.comb += Assume(ResetSignal() == Initial())
+
+        def set_any_const(v, *, src_loc_at=0):
+            m.d.comb += v.eq(AnyConst(v.shape(), src_loc_at=src_loc_at + 1))
+
+        set_any_const(i_data.REDPOLY)
+        set_any_const(i_data.a)
+
+        drp = DecodeReducingPolynomial(XLEN)
+        m.submodules.drp = drp
+        m.d.comb += drp.REDPOLY.eq(i_data.REDPOLY)
+
+        def get_degree(v, *, src_loc_at=0):
+            width = v.shape().width
+            deg = Signal(range(-1, width), src_loc_at=src_loc_at + 1)
+            set_any_const(deg)
+            for i in range(width):
+                with m.If(v >> i == 1):
+                    m.d.comb += deg.eq(i)
+            with m.If(v == 0):
+                m.d.comb += deg.eq(-1)
+            return deg
+
+        rpoly_degree = get_degree(drp.reducing_polynomial)
+        a_degree = get_degree(i_data.a)
+        output_degree = get_degree(o_data.output)
+
+        m.d.comb += Assume(a_degree < rpoly_degree)
+
+        m.d.comb += dut.p.i_valid.eq(Past(Initial()))
+        m.d.comb += dut.n.i_ready.eq(0)
+        with m.If(Past(Initial(), clocks=XLEN * 2 + 2)):
+            m.d.comb += Assert(dut.n.o_valid)
+
+            rpoly_valid = Signal()
+            m.d.comb += rpoly_valid.eq(0)
+
+            poly_map = {}
+
+            for poly in range(2 ** (XLEN + 1)):
+                # too complex to check in hardware, so just list valid values
+                if is_irreducible(poly):
+                    REDPOLY = poly
+                    if REDPOLY >> XLEN != 0:
+                        REDPOLY %= 2 ** XLEN
+                        REDPOLY &= ~1
+                    elif poly == 2:
+                        REDPOLY = 0
+                    with m.If(drp.REDPOLY == REDPOLY):
+                        m.d.comb += rpoly_valid.eq(1)
+                        m.d.comb += Assert(drp.reducing_polynomial == poly)
+                        with self.subTest(poly=hex(poly),
+                                          REDPOLY=hex(REDPOLY)):
+                            self.assertNotIn(REDPOLY, poly_map)
+                        poly_map[REDPOLY] = poly
+
+            m.d.comb += Assume(rpoly_valid)
+
+            unreduced_inverse = Signal(XLEN * 2)
+            m.d.comb += unreduced_inverse.eq(clmul(i_data.a, o_data.output))
+
+            quot = Signal.like(drp.reducing_polynomial)
+            set_any_const(quot)
+            rem = Signal.like(drp.reducing_polynomial)
+            set_any_const(rem)
+            m.d.comb += Assume(unreduced_inverse == clmuladd(
+                drp.reducing_polynomial, quot, rem))
+            rem_degree = get_degree(rem)
+            m.d.comb += Assume(rem_degree < rpoly_degree)
+
+            inverse = Signal(XLEN * 2)
+            m.d.comb += inverse.eq(clmul(i_data.a, o_data.output))
+
+            output_valid = Signal()
+            output_valid_v = i_data.a == 0
+            output_valid_v &= o_data.output == 0
+            output_valid_v |= rem == 1
+            output_valid_v &= output_degree < rpoly_degree
+            m.d.comb += output_valid.eq(output_valid_v)
+            m.d.comb += Assert(output_valid)
+        self.assertFormal(m, depth=XLEN * 2 + 3)
+
+    def test_formal_4(self):
+        self.tst_formal(4)
+
+    def test_formal_8(self):
+        self.tst_formal(8)
+
 
 if __name__ == "__main__":
     unittest.main()