add formal proof for OP_RLC
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 24 Feb 2022 02:40:48 +0000 (18:40 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 24 Feb 2022 02:40:48 +0000 (18:40 -0800)
src/soc/fu/shift_rot/formal/proof_main_stage.py

index be0c4b169fd94a79795c5c7781ba8e66be8b7566..576a1658021d2d5fdbd578263d85ff62cf0a3dc5 100644 (file)
@@ -8,10 +8,11 @@ Links:
 import enum
 from shutil import which
 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
-                    signed, Array, Const, Value)
+                    signed, Array, Const, Value, unsigned)
 from nmigen.asserts import Assert, AnyConst, Assume, Cover
 from nmutil.formaltest import FHDLTestCase
-from nmigen.cli import rtlil
+from nmutil.sim_util import do_sim
+from nmigen.sim import Delay
 
 from soc.fu.shift_rot.main_stage import ShiftRotMainStage
 from soc.fu.shift_rot.rotator import right_mask, left_mask
@@ -31,7 +32,8 @@ class TstOp(enum.Enum):
     also the formal proofs can be run in parallel."""
     SHL = MicrOp.OP_SHL
     SHR = MicrOp.OP_SHR
-    RLC = MicrOp.OP_RLC
+    RLC32 = MicrOp.OP_RLC, 32
+    RLC64 = MicrOp.OP_RLC, 64
     RLCL = MicrOp.OP_RLCL
     RLCR = MicrOp.OP_RLCR
     EXTSWSLI = MicrOp.OP_EXTSWSLI
@@ -60,7 +62,7 @@ class Mask(Elaboratable):
 
     def elaborate(self, platform):
         m = Module()
-        max_val = Const(~0, 64)
+        max_val = Const(~0, unsigned(64))
         max_bit = 63
         with m.If(self.start == 0):
             m.d.comb += self.out.eq(max_val << (max_bit - self.end))
@@ -72,6 +74,37 @@ class Mask(Elaboratable):
         return m
 
 
+class TstMask(unittest.TestCase):
+    def test_mask(self):
+        dut = Mask()
+
+        def case(start, end, expected):
+            with self.subTest(start=start, end=end):
+                yield dut.start.eq(start)
+                yield dut.end.eq(end)
+                yield Delay(1e-6)
+                out = yield dut.out
+                with self.subTest(out=hex(out), expected=hex(expected)):
+                    self.assertEqual(expected, out)
+
+        def process():
+            for start in range(64):
+                for end in range(64):
+                    expected = 0
+                    if start > end:
+                        for i in range(start, 64):
+                            expected |= 1 << (63 - i)
+                        for i in range(0, end + 1):
+                            expected |= 1 << (63 - i)
+                    else:
+                        for i in range(start, end + 1):
+                            expected |= 1 << (63 - i)
+                    yield from case(start, end, expected)
+        with do_sim(self, dut, [dut.start, dut.end, dut.out]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+
 def rotl64(v, amt):
     v |= Const(0, 64)  # convert to value at least 64-bits wide
     amt |= Const(0, 6)  # convert to value at least 6-bits wide
@@ -388,11 +421,42 @@ class Driver(Elaboratable):
         m.d.comb += Assert(dut.o.o.data == expected)
         m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2))
 
-    def _check_rlc(self, m, dut):
-        raise NotImplementedError
+    def _check_rlc32(self, m, dut):
+        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
+        # rlwimi, rlwinm, and rlwnm
+
         m.submodules.mask = mask = Mask()
-        with m.If():
-            pass
+        expected = Signal(64)
+        rot = Signal(64)
+        m.d.comb += rot.eq(rotl32(dut.i.rs[:32], dut.i.rb[:5]))
+        m.d.comb += mask.start.eq(dut.fields.FormM.MB[:] + 32)
+        m.d.comb += mask.end.eq(dut.fields.FormM.ME[:] + 32)
+
+        # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
+        # the expression turns into a no-op
+        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
+        m.d.comb += Assert(dut.o.o.data == expected)
+        m.d.comb += Assert(dut.o.xer_ca.data == 0)
+
+    def _check_rlc64(self, m, dut):
+        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
+        # rldic and rldimi
+
+        # `rb` is always a 6-bit immediate
+        m.d.comb += Assume(dut.i.rb[6:] == 0)
+
+        m.submodules.mask = mask = Mask()
+        expected = Signal(64)
+        rot = Signal(64)
+        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))
+        mb = dut.fields.FormMD.mb[:]
+        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))
+        m.d.comb += mask.end.eq(63 - dut.i.rb[:6])
+
+        # for rldic, ra is guaranteed to be 0, so that part of
+        # the expression turns into a no-op
+        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
+        m.d.comb += Assert(dut.o.o.data == expected)
         m.d.comb += Assert(dut.o.xer_ca.data == 0)
 
     def _check_rlcl(self, m, dut):
@@ -450,8 +514,11 @@ class ALUTestCase(FHDLTestCase):
     def test_shr(self):
         self.run_it(TstOp.SHR)
 
-    def test_rlc(self):
-        self.run_it(TstOp.RLC)
+    def test_rlc32(self):
+        self.run_it(TstOp.RLC32)
+
+    def test_rlc64(self):
+        self.run_it(TstOp.RLC64)
 
     def test_rlcl(self):
         self.run_it(TstOp.RLCL)