From: Jacob Lifshay Date: Thu, 4 Aug 2022 05:41:28 +0000 (-0700) Subject: add BetterMultiPriorityPicker and formal proof X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=e375a4fa59d1e38ecedde4c9fe6155d6023e6742;p=nmutil.git add BetterMultiPriorityPicker and formal proof --- diff --git a/src/nmutil/formal/test_picker.py b/src/nmutil/formal/test_picker.py index 1abf56f..caaf007 100644 --- a/src/nmutil/formal/test_picker.py +++ b/src/nmutil/formal/test_picker.py @@ -7,7 +7,9 @@ import unittest from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, Array, Shape, Mux from nmigen.hdl.dsl import Module from nmutil.formaltest import FHDLTestCase -from nmutil.picker import PriorityPicker, MultiPriorityPicker +from nmutil.picker import (BetterMultiPriorityPicker, PriorityPicker, + MultiPriorityPicker) +from nmutil.sim_util import write_il class TestPriorityPicker(FHDLTestCase): @@ -227,61 +229,65 @@ class TestPriorityPicker(FHDLTestCase): class TestMultiPriorityPicker(FHDLTestCase): - def tst(self, *, cls=MultiPriorityPicker, wid, levels, indices, multi_in): - assert isinstance(wid, int) and wid >= 1 - assert isinstance(levels, int) and 1 <= levels <= wid - assert isinstance(indices, bool) - assert isinstance(multi_in, bool) - dut = cls(wid=wid, levels=levels, indices=indices, multi_in=multi_in) - self.assertEqual(wid, dut.wid) + def make_dut(self, width, levels, indices, multi_in): + dut = MultiPriorityPicker(wid=width, levels=levels, indices=indices, + multi_in=multi_in) + self.assertEqual(width, dut.wid) self.assertEqual(levels, dut.levels) self.assertEqual(indices, dut.indices) self.assertEqual(multi_in, dut.multi_in) + return dut + + def tst(self, *, width, levels, indices, multi_in): + assert isinstance(width, int) and width >= 1 + assert isinstance(levels, int) and 1 <= levels <= width + assert isinstance(indices, bool) + assert isinstance(multi_in, bool) + dut = self.make_dut(width=width, levels=levels, indices=indices, + multi_in=multi_in) expected_ports = [] if multi_in: - self.assertIsInstance(dut.i, Array) + self.assertIsInstance(dut.i, (Array, list)) self.assertEqual(len(dut.i), levels) for i in dut.i: self.assertIsInstance(i, Signal) - self.assertEqual(len(i), wid) + self.assertEqual(len(i), width) expected_ports.append(i) else: self.assertIsInstance(dut.i, Signal) - self.assertEqual(len(dut.i), wid) + self.assertEqual(len(dut.i), width) expected_ports.append(dut.i) - self.assertIsInstance(dut.o, Array) + self.assertIsInstance(dut.o, (Array, list)) self.assertEqual(len(dut.o), levels) for o in dut.o: self.assertIsInstance(o, Signal) - self.assertEqual(len(o), wid) + self.assertEqual(len(o), width) expected_ports.append(o) self.assertEqual(len(dut.en_o), levels) expected_ports.append(dut.en_o) if indices: - expected_idx_o_shape = Shape.cast(range(levels)) - if levels <= 1: - expected_idx_o_shape = Shape(0, False) - self.assertIsInstance(dut.idx_o, Array) + self.assertIsInstance(dut.idx_o, (Array, list)) self.assertEqual(len(dut.idx_o), levels) for idx_o in dut.idx_o: self.assertIsInstance(idx_o, Signal) - self.assertEqual(idx_o.shape(), expected_idx_o_shape) expected_ports.append(idx_o) else: self.assertFalse(hasattr(dut, "idx_o")) self.assertListEqual(expected_ports, dut.ports()) + write_il(self, dut, ports=dut.ports()) + m = Module() m.submodules.dut = dut if multi_in: for i in dut.i: - m.d.comb += i.eq(AnyConst(wid)) + m.d.comb += i.eq(AnyConst(width)) else: - m.d.comb += dut.i.eq(AnyConst(wid)) + m.d.comb += dut.i.eq(AnyConst(width)) prev_set = 0 for o, en_o in zip(dut.o, dut.en_o): @@ -293,14 +299,14 @@ class TestMultiPriorityPicker(FHDLTestCase): m.d.comb += Assert((o != 0) == en_o) - prev_set = Const(0, wid) - priority_pickers = [PriorityPicker(wid) for _ in range(levels)] + prev_set = Const(0, width) + priority_pickers = [PriorityPicker(width) for _ in range(levels)] for level in range(levels): pp = priority_pickers[level] setattr(m.submodules, f"pp_{level}", pp) inp = dut.i[level] if multi_in else dut.i m.d.comb += pp.i.eq(inp & ~prev_set) - cur_set = Signal(wid, name=f"cur_set_{level}") + cur_set = Signal(width, name=f"cur_set_{level}") m.d.comb += cur_set.eq(prev_set | pp.o) prev_set = cur_set m.d.comb += Assert(pp.o == dut.o[level]) @@ -314,148 +320,165 @@ class TestMultiPriorityPicker(FHDLTestCase): self.assertFormal(m) def test_4_levels_1_idxs_f_mi_f(self): - self.tst(wid=4, levels=1, indices=False, multi_in=False) + self.tst(width=4, levels=1, indices=False, multi_in=False) def test_4_levels_1_idxs_f_mi_t(self): - self.tst(wid=4, levels=1, indices=False, multi_in=True) + self.tst(width=4, levels=1, indices=False, multi_in=True) def test_4_levels_1_idxs_t_mi_f(self): - self.tst(wid=4, levels=1, indices=True, multi_in=False) + self.tst(width=4, levels=1, indices=True, multi_in=False) def test_4_levels_1_idxs_t_mi_t(self): - self.tst(wid=4, levels=1, indices=True, multi_in=True) + self.tst(width=4, levels=1, indices=True, multi_in=True) def test_4_levels_2_idxs_f_mi_f(self): - self.tst(wid=4, levels=2, indices=False, multi_in=False) + self.tst(width=4, levels=2, indices=False, multi_in=False) def test_4_levels_2_idxs_f_mi_t(self): - self.tst(wid=4, levels=2, indices=False, multi_in=True) + self.tst(width=4, levels=2, indices=False, multi_in=True) def test_4_levels_2_idxs_t_mi_f(self): - self.tst(wid=4, levels=2, indices=True, multi_in=False) + self.tst(width=4, levels=2, indices=True, multi_in=False) def test_4_levels_2_idxs_t_mi_t(self): - self.tst(wid=4, levels=2, indices=True, multi_in=True) + self.tst(width=4, levels=2, indices=True, multi_in=True) def test_4_levels_3_idxs_f_mi_f(self): - self.tst(wid=4, levels=3, indices=False, multi_in=False) + self.tst(width=4, levels=3, indices=False, multi_in=False) def test_4_levels_3_idxs_f_mi_t(self): - self.tst(wid=4, levels=3, indices=False, multi_in=True) + self.tst(width=4, levels=3, indices=False, multi_in=True) def test_4_levels_3_idxs_t_mi_f(self): - self.tst(wid=4, levels=3, indices=True, multi_in=False) + self.tst(width=4, levels=3, indices=True, multi_in=False) def test_4_levels_3_idxs_t_mi_t(self): - self.tst(wid=4, levels=3, indices=True, multi_in=True) + self.tst(width=4, levels=3, indices=True, multi_in=True) def test_4_levels_4_idxs_f_mi_f(self): - self.tst(wid=4, levels=4, indices=False, multi_in=False) + self.tst(width=4, levels=4, indices=False, multi_in=False) def test_4_levels_4_idxs_f_mi_t(self): - self.tst(wid=4, levels=4, indices=False, multi_in=True) + self.tst(width=4, levels=4, indices=False, multi_in=True) def test_4_levels_4_idxs_t_mi_f(self): - self.tst(wid=4, levels=4, indices=True, multi_in=False) + self.tst(width=4, levels=4, indices=True, multi_in=False) def test_4_levels_4_idxs_t_mi_t(self): - self.tst(wid=4, levels=4, indices=True, multi_in=True) + self.tst(width=4, levels=4, indices=True, multi_in=True) def test_8_levels_1_idxs_f_mi_f(self): - self.tst(wid=8, levels=1, indices=False, multi_in=False) + self.tst(width=8, levels=1, indices=False, multi_in=False) def test_8_levels_1_idxs_f_mi_t(self): - self.tst(wid=8, levels=1, indices=False, multi_in=True) + self.tst(width=8, levels=1, indices=False, multi_in=True) def test_8_levels_1_idxs_t_mi_f(self): - self.tst(wid=8, levels=1, indices=True, multi_in=False) + self.tst(width=8, levels=1, indices=True, multi_in=False) def test_8_levels_1_idxs_t_mi_t(self): - self.tst(wid=8, levels=1, indices=True, multi_in=True) + self.tst(width=8, levels=1, indices=True, multi_in=True) def test_8_levels_2_idxs_f_mi_f(self): - self.tst(wid=8, levels=2, indices=False, multi_in=False) + self.tst(width=8, levels=2, indices=False, multi_in=False) def test_8_levels_2_idxs_f_mi_t(self): - self.tst(wid=8, levels=2, indices=False, multi_in=True) + self.tst(width=8, levels=2, indices=False, multi_in=True) def test_8_levels_2_idxs_t_mi_f(self): - self.tst(wid=8, levels=2, indices=True, multi_in=False) + self.tst(width=8, levels=2, indices=True, multi_in=False) def test_8_levels_2_idxs_t_mi_t(self): - self.tst(wid=8, levels=2, indices=True, multi_in=True) + self.tst(width=8, levels=2, indices=True, multi_in=True) def test_8_levels_3_idxs_f_mi_f(self): - self.tst(wid=8, levels=3, indices=False, multi_in=False) + self.tst(width=8, levels=3, indices=False, multi_in=False) def test_8_levels_3_idxs_f_mi_t(self): - self.tst(wid=8, levels=3, indices=False, multi_in=True) + self.tst(width=8, levels=3, indices=False, multi_in=True) def test_8_levels_3_idxs_t_mi_f(self): - self.tst(wid=8, levels=3, indices=True, multi_in=False) + self.tst(width=8, levels=3, indices=True, multi_in=False) def test_8_levels_3_idxs_t_mi_t(self): - self.tst(wid=8, levels=3, indices=True, multi_in=True) + self.tst(width=8, levels=3, indices=True, multi_in=True) def test_8_levels_4_idxs_f_mi_f(self): - self.tst(wid=8, levels=4, indices=False, multi_in=False) + self.tst(width=8, levels=4, indices=False, multi_in=False) def test_8_levels_4_idxs_f_mi_t(self): - self.tst(wid=8, levels=4, indices=False, multi_in=True) + self.tst(width=8, levels=4, indices=False, multi_in=True) def test_8_levels_4_idxs_t_mi_f(self): - self.tst(wid=8, levels=4, indices=True, multi_in=False) + self.tst(width=8, levels=4, indices=True, multi_in=False) def test_8_levels_4_idxs_t_mi_t(self): - self.tst(wid=8, levels=4, indices=True, multi_in=True) + self.tst(width=8, levels=4, indices=True, multi_in=True) def test_8_levels_5_idxs_f_mi_f(self): - self.tst(wid=8, levels=5, indices=False, multi_in=False) + self.tst(width=8, levels=5, indices=False, multi_in=False) def test_8_levels_5_idxs_f_mi_t(self): - self.tst(wid=8, levels=5, indices=False, multi_in=True) + self.tst(width=8, levels=5, indices=False, multi_in=True) def test_8_levels_5_idxs_t_mi_f(self): - self.tst(wid=8, levels=5, indices=True, multi_in=False) + self.tst(width=8, levels=5, indices=True, multi_in=False) def test_8_levels_5_idxs_t_mi_t(self): - self.tst(wid=8, levels=5, indices=True, multi_in=True) + self.tst(width=8, levels=5, indices=True, multi_in=True) def test_8_levels_6_idxs_f_mi_f(self): - self.tst(wid=8, levels=6, indices=False, multi_in=False) + self.tst(width=8, levels=6, indices=False, multi_in=False) def test_8_levels_6_idxs_f_mi_t(self): - self.tst(wid=8, levels=6, indices=False, multi_in=True) + self.tst(width=8, levels=6, indices=False, multi_in=True) def test_8_levels_6_idxs_t_mi_f(self): - self.tst(wid=8, levels=6, indices=True, multi_in=False) + self.tst(width=8, levels=6, indices=True, multi_in=False) def test_8_levels_6_idxs_t_mi_t(self): - self.tst(wid=8, levels=6, indices=True, multi_in=True) + self.tst(width=8, levels=6, indices=True, multi_in=True) def test_8_levels_7_idxs_f_mi_f(self): - self.tst(wid=8, levels=7, indices=False, multi_in=False) + self.tst(width=8, levels=7, indices=False, multi_in=False) def test_8_levels_7_idxs_f_mi_t(self): - self.tst(wid=8, levels=7, indices=False, multi_in=True) + self.tst(width=8, levels=7, indices=False, multi_in=True) def test_8_levels_7_idxs_t_mi_f(self): - self.tst(wid=8, levels=7, indices=True, multi_in=False) + self.tst(width=8, levels=7, indices=True, multi_in=False) def test_8_levels_7_idxs_t_mi_t(self): - self.tst(wid=8, levels=7, indices=True, multi_in=True) + self.tst(width=8, levels=7, indices=True, multi_in=True) def test_8_levels_8_idxs_f_mi_f(self): - self.tst(wid=8, levels=8, indices=False, multi_in=False) + self.tst(width=8, levels=8, indices=False, multi_in=False) def test_8_levels_8_idxs_f_mi_t(self): - self.tst(wid=8, levels=8, indices=False, multi_in=True) + self.tst(width=8, levels=8, indices=False, multi_in=True) def test_8_levels_8_idxs_t_mi_f(self): - self.tst(wid=8, levels=8, indices=True, multi_in=False) + self.tst(width=8, levels=8, indices=True, multi_in=False) def test_8_levels_8_idxs_t_mi_t(self): - self.tst(wid=8, levels=8, indices=True, multi_in=True) + self.tst(width=8, levels=8, indices=True, multi_in=True) + + def test_16_levels_16_idxs_f_mi_f(self): + self.tst(width=16, levels=16, indices=False, multi_in=False) + + +class TestBetterMultiPriorityPicker(TestMultiPriorityPicker): + def make_dut(self, width, levels, indices, multi_in): + if multi_in: + self.skipTest( + "multi_in are not supported by BetterMultiPriorityPicker") + if indices: + self.skipTest( + "indices are not supported by BetterMultiPriorityPicker") + dut = BetterMultiPriorityPicker(width=width, levels=levels) + self.assertEqual(width, dut.width) + self.assertEqual(levels, dut.levels) + return dut if __name__ == "__main__": diff --git a/src/nmutil/picker.py b/src/nmutil/picker.py index ab56741..7aab175 100644 --- a/src/nmutil/picker.py +++ b/src/nmutil/picker.py @@ -26,8 +26,10 @@ """ from nmigen import Module, Signal, Cat, Elaboratable, Array, Const, Mux +from nmigen.utils import bits_for from nmigen.cli import rtlil import math +from nmutil.prefix_sum import prefix_sum class PriorityPicker(Elaboratable): @@ -211,6 +213,64 @@ class MultiPriorityPicker(Elaboratable): return list(self) +class BetterMultiPriorityPicker(Elaboratable): + """A better replacement for MultiPriorityPicker that has O(log levels) + latency, rather than > O(levels) latency. + """ + + def __init__(self, width, levels, *, work_efficient=False): + assert isinstance(width, int) and width >= 1 + assert isinstance(levels, int) and 1 <= levels <= width + assert isinstance(work_efficient, bool) + self.width = width + self.levels = levels + self.work_efficient = work_efficient + assert self.__index_sat > self.levels - 1 + self.i = Signal(width) + self.o = [Signal(width, name=f"o_{i}") for i in range(levels)] + self.en_o = Signal(levels) + + @property + def __index_width(self): + return bits_for(self.levels) + + @property + def __index_sat(self): + return (1 << self.__index_width) - 1 + + def elaborate(self, platform): + m = Module() + + def sat_add(a, b): + sum = Signal(self.__index_width + 1) + m.d.comb += sum.eq(a + b) + retval = Signal(self.__index_width) + m.d.comb += retval.eq(Mux(sum[-1], self.__index_sat, sum)) + return retval + indexes = prefix_sum((self.i[i] for i in range(self.width - 1)), + sat_add, work_efficient=self.work_efficient) + indexes.insert(0, 0) + for i in range(self.width): + sig = Signal(self.__index_width, name=f"index_{i}") + m.d.comb += sig.eq(indexes[i]) + indexes[i] = sig + for level in range(self.levels): + m.d.comb += self.en_o[level].eq(self.o[level].bool()) + for i in range(self.width): + index_matches = indexes[i] == level + m.d.comb += self.o[level][i].eq(index_matches & self.i[i]) + + return m + + def __iter__(self): + yield self.i + yield from self.o + yield self.en_o + + def ports(self): + return list(self) + + if __name__ == '__main__': dut = PriorityPicker(16) vl = rtlil.convert(dut, ports=dut.ports())