hdl/test/test_gfbmadd: add formal proof of gfbmadd bug_785_gfb_insns
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 18 May 2024 00:41:07 +0000 (17:41 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 18 May 2024 00:41:07 +0000 (17:41 -0700)
src/nmigen_gf/hdl/test/test_gfbmadd.py

index 6b8da5744ab2bd175f068c33a626e3199db844f8..b2a9f5604198df36152b69d48e5136e46ea9adc0 100644 (file)
@@ -10,11 +10,15 @@ https://bugs.libre-soc.org/show_bug.cgi?id=785
 """
 
 import unittest
-from nmigen.hdl.ast import Const, unsigned
+from nmigen.hdl.ast import AnyConst, Assert, Assume, Const, Initial, \
+    Past, ResetSignal, Signal, unsigned
+from nmigen.hdl.dsl import Module
 from nmutil.formaltest import FHDLTestCase
 from nmigen_gf.reference.gfbmadd import gfbmadd
+from nmigen_gf.hdl.decode_reducing_polynomial import DecodeReducingPolynomial
 from nmigen_gf.hdl.gfbmadd import \
     py_gfbmadd_algorithm, GFBMAddFSMStage, GFBMAddShape
+from nmigen_gf.hdl.clmul import clmuladd
 from nmigen.sim import Delay, Tick
 from nmutil.sim_util import do_sim, hash_256
 from nmigen_gf.reference.state import ST
@@ -162,6 +166,71 @@ class TestGFBMAdd(FHDLTestCase):
     def test_64(self):
         self.tst(64)
 
+    def tst_formal(self, XLEN):
+        # type: (int) -> None
+        m = Module()
+        shape = GFBMAddShape(width=XLEN)
+        pspec = {}
+        dut = GFBMAddFSMStage(pspec, shape)
+        m.submodules.dut = dut
+        i_data = dut.p.i_data
+        o_data = dut.n.o_data
+
+        # for some reason formal likes to keep reset asserted for the entire
+        # trace, force it to just be asserted at the beginning
+        m.d.comb += Assume(ResetSignal() == Initial())
+
+        def set_any_const(v, *, src_loc_at=0):
+            m.d.comb += v.eq(AnyConst(v.shape(), src_loc_at=src_loc_at + 1))
+
+        set_any_const(i_data.REDPOLY)
+        set_any_const(i_data.factor1)
+        set_any_const(i_data.factor2)
+        set_any_const(i_data.term)
+
+        drp = DecodeReducingPolynomial(XLEN)
+        m.submodules.drp = drp
+        m.d.comb += drp.REDPOLY.eq(i_data.REDPOLY)
+
+        def get_degree(v, *, src_loc_at=0):
+            width = v.shape().width
+            deg = Signal(range(-1, width), src_loc_at=src_loc_at + 1)
+            set_any_const(deg)
+            for i in range(width):
+                with m.If(v >> i == 1):
+                    m.d.comb += deg.eq(i)
+            with m.If(v == 0):
+                m.d.comb += deg.eq(-1)
+            return deg
+
+        rpoly_degree = get_degree(drp.reducing_polynomial)
+
+        m.d.comb += dut.p.i_valid.eq(Past(Initial()))
+        m.d.comb += dut.n.i_ready.eq(0)
+        with m.If(Past(Initial(), clocks=shape.step_count + 1)):
+            m.d.comb += Assert(dut.n.o_valid)
+
+            unreduced_output = Signal(XLEN * 2)
+            m.d.comb += unreduced_output.eq(
+                clmuladd(i_data.factor1, i_data.factor2, i_data.term))
+
+            quot = Signal.like(drp.reducing_polynomial)
+            set_any_const(quot)
+            rem = Signal.like(drp.reducing_polynomial)
+            set_any_const(rem)
+            m.d.comb += Assume(unreduced_output == clmuladd(
+                drp.reducing_polynomial, quot, rem))
+            rem_degree = get_degree(rem)
+            m.d.comb += Assume(rem_degree < rpoly_degree)
+            m.d.comb += Assert(rem == o_data.output)
+        self.assertFormal(m, depth=shape.step_count + 1)
+
+    def test_formal_4(self):
+        self.tst_formal(4)
+
+    def test_formal_5(self):
+        self.tst_formal(5)
+
 
 if __name__ == "__main__":
     unittest.main()