add formal proof for MultiPriorityPicker
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 4 Aug 2022 04:09:53 +0000 (21:09 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 4 Aug 2022 04:09:53 +0000 (21:09 -0700)
src/nmutil/formal/test_picker.py
src/nmutil/picker.py

index 573129c95c0d9a4aaca4c3df8e28bc2e64c2bb93..1abf56ff101101e444367c000b12ce64bdc488db 100644 (file)
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: LGPL-3-or-later
 # Copyright 2022 Jacob Lifshay
 
+from functools import reduce
+import operator
 import unittest
-from nmigen.hdl.ast import AnyConst, Assert, Signal, Const
+from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, Array, Shape, Mux
 from nmigen.hdl.dsl import Module
 from nmutil.formaltest import FHDLTestCase
 from nmutil.picker import PriorityPicker, MultiPriorityPicker
@@ -224,5 +226,237 @@ class TestPriorityPicker(FHDLTestCase):
         self.tst(wid=64, msb_mode=True, reverse_i=True, reverse_o=True)
 
 
+class TestMultiPriorityPicker(FHDLTestCase):
+    def tst(self, *, cls=MultiPriorityPicker, wid, levels, indices, multi_in):
+        assert isinstance(wid, int) and wid >= 1
+        assert isinstance(levels, int) and 1 <= levels <= wid
+        assert isinstance(indices, bool)
+        assert isinstance(multi_in, bool)
+        dut = cls(wid=wid, levels=levels, indices=indices, multi_in=multi_in)
+        self.assertEqual(wid, dut.wid)
+        self.assertEqual(levels, dut.levels)
+        self.assertEqual(indices, dut.indices)
+        self.assertEqual(multi_in, dut.multi_in)
+        expected_ports = []
+        if multi_in:
+            self.assertIsInstance(dut.i, Array)
+            self.assertEqual(len(dut.i), levels)
+            for i in dut.i:
+                self.assertIsInstance(i, Signal)
+                self.assertEqual(len(i), wid)
+                expected_ports.append(i)
+        else:
+            self.assertIsInstance(dut.i, Signal)
+            self.assertEqual(len(dut.i), wid)
+            expected_ports.append(dut.i)
+
+        self.assertIsInstance(dut.o, Array)
+        self.assertEqual(len(dut.o), levels)
+        for o in dut.o:
+            self.assertIsInstance(o, Signal)
+            self.assertEqual(len(o), wid)
+            expected_ports.append(o)
+
+        self.assertEqual(len(dut.en_o), levels)
+        expected_ports.append(dut.en_o)
+
+        if indices:
+            expected_idx_o_shape = Shape.cast(range(levels))
+            if levels <= 1:
+                expected_idx_o_shape = Shape(0, False)
+            self.assertIsInstance(dut.idx_o, Array)
+            self.assertEqual(len(dut.idx_o), levels)
+            for idx_o in dut.idx_o:
+                self.assertIsInstance(idx_o, Signal)
+                self.assertEqual(idx_o.shape(), expected_idx_o_shape)
+                expected_ports.append(idx_o)
+        else:
+            self.assertFalse(hasattr(dut, "idx_o"))
+
+        self.assertListEqual(expected_ports, dut.ports())
+
+        m = Module()
+        m.submodules.dut = dut
+        if multi_in:
+            for i in dut.i:
+                m.d.comb += i.eq(AnyConst(wid))
+        else:
+            m.d.comb += dut.i.eq(AnyConst(wid))
+
+        prev_set = 0
+        for o, en_o in zip(dut.o, dut.en_o):
+            # assert o only has zero or one bit set
+            m.d.comb += Assert((o & (o - 1)) == 0)
+            # assert o doesn't overlap any previous outputs
+            m.d.comb += Assert((o & prev_set) == 0)
+            prev_set |= o
+
+            m.d.comb += Assert((o != 0) == en_o)
+
+        prev_set = Const(0, wid)
+        priority_pickers = [PriorityPicker(wid) for _ in range(levels)]
+        for level in range(levels):
+            pp = priority_pickers[level]
+            setattr(m.submodules, f"pp_{level}", pp)
+            inp = dut.i[level] if multi_in else dut.i
+            m.d.comb += pp.i.eq(inp & ~prev_set)
+            cur_set = Signal(wid, name=f"cur_set_{level}")
+            m.d.comb += cur_set.eq(prev_set | pp.o)
+            prev_set = cur_set
+            m.d.comb += Assert(pp.o == dut.o[level])
+            expected_idx = Signal(32, name=f"expected_idx_{level}")
+            number_of_prev_en_o_set = reduce(
+                operator.add, (i.en_o for i in priority_pickers[:level]), 0)
+            m.d.comb += expected_idx.eq(number_of_prev_en_o_set)
+            if indices:
+                m.d.comb += Assert(expected_idx == dut.idx_o[level])
+
+        self.assertFormal(m)
+
+    def test_4_levels_1_idxs_f_mi_f(self):
+        self.tst(wid=4, levels=1, indices=False, multi_in=False)
+
+    def test_4_levels_1_idxs_f_mi_t(self):
+        self.tst(wid=4, levels=1, indices=False, multi_in=True)
+
+    def test_4_levels_1_idxs_t_mi_f(self):
+        self.tst(wid=4, levels=1, indices=True, multi_in=False)
+
+    def test_4_levels_1_idxs_t_mi_t(self):
+        self.tst(wid=4, levels=1, indices=True, multi_in=True)
+
+    def test_4_levels_2_idxs_f_mi_f(self):
+        self.tst(wid=4, levels=2, indices=False, multi_in=False)
+
+    def test_4_levels_2_idxs_f_mi_t(self):
+        self.tst(wid=4, levels=2, indices=False, multi_in=True)
+
+    def test_4_levels_2_idxs_t_mi_f(self):
+        self.tst(wid=4, levels=2, indices=True, multi_in=False)
+
+    def test_4_levels_2_idxs_t_mi_t(self):
+        self.tst(wid=4, levels=2, indices=True, multi_in=True)
+
+    def test_4_levels_3_idxs_f_mi_f(self):
+        self.tst(wid=4, levels=3, indices=False, multi_in=False)
+
+    def test_4_levels_3_idxs_f_mi_t(self):
+        self.tst(wid=4, levels=3, indices=False, multi_in=True)
+
+    def test_4_levels_3_idxs_t_mi_f(self):
+        self.tst(wid=4, levels=3, indices=True, multi_in=False)
+
+    def test_4_levels_3_idxs_t_mi_t(self):
+        self.tst(wid=4, levels=3, indices=True, multi_in=True)
+
+    def test_4_levels_4_idxs_f_mi_f(self):
+        self.tst(wid=4, levels=4, indices=False, multi_in=False)
+
+    def test_4_levels_4_idxs_f_mi_t(self):
+        self.tst(wid=4, levels=4, indices=False, multi_in=True)
+
+    def test_4_levels_4_idxs_t_mi_f(self):
+        self.tst(wid=4, levels=4, indices=True, multi_in=False)
+
+    def test_4_levels_4_idxs_t_mi_t(self):
+        self.tst(wid=4, levels=4, indices=True, multi_in=True)
+
+    def test_8_levels_1_idxs_f_mi_f(self):
+        self.tst(wid=8, levels=1, indices=False, multi_in=False)
+
+    def test_8_levels_1_idxs_f_mi_t(self):
+        self.tst(wid=8, levels=1, indices=False, multi_in=True)
+
+    def test_8_levels_1_idxs_t_mi_f(self):
+        self.tst(wid=8, levels=1, indices=True, multi_in=False)
+
+    def test_8_levels_1_idxs_t_mi_t(self):
+        self.tst(wid=8, levels=1, indices=True, multi_in=True)
+
+    def test_8_levels_2_idxs_f_mi_f(self):
+        self.tst(wid=8, levels=2, indices=False, multi_in=False)
+
+    def test_8_levels_2_idxs_f_mi_t(self):
+        self.tst(wid=8, levels=2, indices=False, multi_in=True)
+
+    def test_8_levels_2_idxs_t_mi_f(self):
+        self.tst(wid=8, levels=2, indices=True, multi_in=False)
+
+    def test_8_levels_2_idxs_t_mi_t(self):
+        self.tst(wid=8, levels=2, indices=True, multi_in=True)
+
+    def test_8_levels_3_idxs_f_mi_f(self):
+        self.tst(wid=8, levels=3, indices=False, multi_in=False)
+
+    def test_8_levels_3_idxs_f_mi_t(self):
+        self.tst(wid=8, levels=3, indices=False, multi_in=True)
+
+    def test_8_levels_3_idxs_t_mi_f(self):
+        self.tst(wid=8, levels=3, indices=True, multi_in=False)
+
+    def test_8_levels_3_idxs_t_mi_t(self):
+        self.tst(wid=8, levels=3, indices=True, multi_in=True)
+
+    def test_8_levels_4_idxs_f_mi_f(self):
+        self.tst(wid=8, levels=4, indices=False, multi_in=False)
+
+    def test_8_levels_4_idxs_f_mi_t(self):
+        self.tst(wid=8, levels=4, indices=False, multi_in=True)
+
+    def test_8_levels_4_idxs_t_mi_f(self):
+        self.tst(wid=8, levels=4, indices=True, multi_in=False)
+
+    def test_8_levels_4_idxs_t_mi_t(self):
+        self.tst(wid=8, levels=4, indices=True, multi_in=True)
+
+    def test_8_levels_5_idxs_f_mi_f(self):
+        self.tst(wid=8, levels=5, indices=False, multi_in=False)
+
+    def test_8_levels_5_idxs_f_mi_t(self):
+        self.tst(wid=8, levels=5, indices=False, multi_in=True)
+
+    def test_8_levels_5_idxs_t_mi_f(self):
+        self.tst(wid=8, levels=5, indices=True, multi_in=False)
+
+    def test_8_levels_5_idxs_t_mi_t(self):
+        self.tst(wid=8, levels=5, indices=True, multi_in=True)
+
+    def test_8_levels_6_idxs_f_mi_f(self):
+        self.tst(wid=8, levels=6, indices=False, multi_in=False)
+
+    def test_8_levels_6_idxs_f_mi_t(self):
+        self.tst(wid=8, levels=6, indices=False, multi_in=True)
+
+    def test_8_levels_6_idxs_t_mi_f(self):
+        self.tst(wid=8, levels=6, indices=True, multi_in=False)
+
+    def test_8_levels_6_idxs_t_mi_t(self):
+        self.tst(wid=8, levels=6, indices=True, multi_in=True)
+
+    def test_8_levels_7_idxs_f_mi_f(self):
+        self.tst(wid=8, levels=7, indices=False, multi_in=False)
+
+    def test_8_levels_7_idxs_f_mi_t(self):
+        self.tst(wid=8, levels=7, indices=False, multi_in=True)
+
+    def test_8_levels_7_idxs_t_mi_f(self):
+        self.tst(wid=8, levels=7, indices=True, multi_in=False)
+
+    def test_8_levels_7_idxs_t_mi_t(self):
+        self.tst(wid=8, levels=7, indices=True, multi_in=True)
+
+    def test_8_levels_8_idxs_f_mi_f(self):
+        self.tst(wid=8, levels=8, indices=False, multi_in=False)
+
+    def test_8_levels_8_idxs_f_mi_t(self):
+        self.tst(wid=8, levels=8, indices=False, multi_in=True)
+
+    def test_8_levels_8_idxs_t_mi_f(self):
+        self.tst(wid=8, levels=8, indices=True, multi_in=False)
+
+    def test_8_levels_8_idxs_t_mi_t(self):
+        self.tst(wid=8, levels=8, indices=True, multi_in=True)
+
+
 if __name__ == "__main__":
     unittest.main()
index 7cf7f7bd03001f920ce898ab765211d7689832ad..ab56741d0d5c5ff56324d8b818e746729f6df5e7 100644 (file)
@@ -104,13 +104,13 @@ class MultiPriorityPicker(Elaboratable):
         Also outputted (optional): an index for each picked "thing".
     """
 
-    def __init__(self, wid, levels, indices=False, multiin=False):
+    def __init__(self, wid, levels, indices=False, multi_in=False):
         self.levels = levels
         self.wid = wid
         self.indices = indices
-        self.multiin = multiin
+        self.multi_in = multi_in
 
-        if multiin:
+        if multi_in:
             # multiple inputs, multiple outputs.
             i_l = []  # array of picker outputs
             for j in range(self.levels):
@@ -154,7 +154,7 @@ class MultiPriorityPicker(Elaboratable):
         p_mask = None
         pp_l = []
         for j in range(self.levels):
-            if self.multiin:
+            if self.multi_in:
                 i = self.i[j]
             else:
                 i = self.i
@@ -186,28 +186,25 @@ class MultiPriorityPicker(Elaboratable):
 
         # for each picker enabled, pass that out and set a cascading index
         lidx = math.ceil(math.log2(self.levels))
-        prev_count = None
+        prev_count = 0
         for j in range(self.levels):
             en_o = pp_l[j].en_o
-            if prev_count is None:
-                comb += self.idx_o[j].eq(0)
-            else:
-                count1 = Signal(lidx, name="count_%d" % j, reset_less=True)
-                comb += count1.eq(prev_count + Const(1, lidx))
-                comb += self.idx_o[j].eq(Mux(en_o, count1, prev_count))
-            prev_count = self.idx_o[j]
+            count1 = Signal(lidx, name="count_%d" % j, reset_less=True)
+            comb += count1.eq(prev_count + Const(1, lidx))
+            comb += self.idx_o[j].eq(prev_count)
+            prev_count = Mux(en_o, count1, prev_count)
 
         return m
 
     def __iter__(self):
-        if self.multiin:
+        if self.multi_in:
             yield from self.i
         else:
             yield self.i
         yield from self.o
+        yield self.en_o
         if not self.indices:
             return
-        yield self.en_o
         yield from self.idx_o
 
     def ports(self):