From e91a7d4caaeaf3e870a6ed5223d5241298751217 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 3 Aug 2022 21:09:53 -0700 Subject: [PATCH] add formal proof for MultiPriorityPicker --- src/nmutil/formal/test_picker.py | 236 ++++++++++++++++++++++++++++++- src/nmutil/picker.py | 25 ++-- 2 files changed, 246 insertions(+), 15 deletions(-) diff --git a/src/nmutil/formal/test_picker.py b/src/nmutil/formal/test_picker.py index 573129c..1abf56f 100644 --- a/src/nmutil/formal/test_picker.py +++ b/src/nmutil/formal/test_picker.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: LGPL-3-or-later # Copyright 2022 Jacob Lifshay +from functools import reduce +import operator import unittest -from nmigen.hdl.ast import AnyConst, Assert, Signal, Const +from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, Array, Shape, Mux from nmigen.hdl.dsl import Module from nmutil.formaltest import FHDLTestCase from nmutil.picker import PriorityPicker, MultiPriorityPicker @@ -224,5 +226,237 @@ class TestPriorityPicker(FHDLTestCase): self.tst(wid=64, msb_mode=True, reverse_i=True, reverse_o=True) +class TestMultiPriorityPicker(FHDLTestCase): + def tst(self, *, cls=MultiPriorityPicker, wid, levels, indices, multi_in): + assert isinstance(wid, int) and wid >= 1 + assert isinstance(levels, int) and 1 <= levels <= wid + assert isinstance(indices, bool) + assert isinstance(multi_in, bool) + dut = cls(wid=wid, levels=levels, indices=indices, multi_in=multi_in) + self.assertEqual(wid, dut.wid) + self.assertEqual(levels, dut.levels) + self.assertEqual(indices, dut.indices) + self.assertEqual(multi_in, dut.multi_in) + expected_ports = [] + if multi_in: + self.assertIsInstance(dut.i, Array) + self.assertEqual(len(dut.i), levels) + for i in dut.i: + self.assertIsInstance(i, Signal) + self.assertEqual(len(i), wid) + expected_ports.append(i) + else: + self.assertIsInstance(dut.i, Signal) + self.assertEqual(len(dut.i), wid) + expected_ports.append(dut.i) + + self.assertIsInstance(dut.o, Array) + self.assertEqual(len(dut.o), levels) + for o in dut.o: + self.assertIsInstance(o, Signal) + self.assertEqual(len(o), wid) + expected_ports.append(o) + + self.assertEqual(len(dut.en_o), levels) + expected_ports.append(dut.en_o) + + if indices: + expected_idx_o_shape = Shape.cast(range(levels)) + if levels <= 1: + expected_idx_o_shape = Shape(0, False) + self.assertIsInstance(dut.idx_o, Array) + self.assertEqual(len(dut.idx_o), levels) + for idx_o in dut.idx_o: + self.assertIsInstance(idx_o, Signal) + self.assertEqual(idx_o.shape(), expected_idx_o_shape) + expected_ports.append(idx_o) + else: + self.assertFalse(hasattr(dut, "idx_o")) + + self.assertListEqual(expected_ports, dut.ports()) + + m = Module() + m.submodules.dut = dut + if multi_in: + for i in dut.i: + m.d.comb += i.eq(AnyConst(wid)) + else: + m.d.comb += dut.i.eq(AnyConst(wid)) + + prev_set = 0 + for o, en_o in zip(dut.o, dut.en_o): + # assert o only has zero or one bit set + m.d.comb += Assert((o & (o - 1)) == 0) + # assert o doesn't overlap any previous outputs + m.d.comb += Assert((o & prev_set) == 0) + prev_set |= o + + m.d.comb += Assert((o != 0) == en_o) + + prev_set = Const(0, wid) + priority_pickers = [PriorityPicker(wid) for _ in range(levels)] + for level in range(levels): + pp = priority_pickers[level] + setattr(m.submodules, f"pp_{level}", pp) + inp = dut.i[level] if multi_in else dut.i + m.d.comb += pp.i.eq(inp & ~prev_set) + cur_set = Signal(wid, name=f"cur_set_{level}") + m.d.comb += cur_set.eq(prev_set | pp.o) + prev_set = cur_set + m.d.comb += Assert(pp.o == dut.o[level]) + expected_idx = Signal(32, name=f"expected_idx_{level}") + number_of_prev_en_o_set = reduce( + operator.add, (i.en_o for i in priority_pickers[:level]), 0) + m.d.comb += expected_idx.eq(number_of_prev_en_o_set) + if indices: + m.d.comb += Assert(expected_idx == dut.idx_o[level]) + + self.assertFormal(m) + + def test_4_levels_1_idxs_f_mi_f(self): + self.tst(wid=4, levels=1, indices=False, multi_in=False) + + def test_4_levels_1_idxs_f_mi_t(self): + self.tst(wid=4, levels=1, indices=False, multi_in=True) + + def test_4_levels_1_idxs_t_mi_f(self): + self.tst(wid=4, levels=1, indices=True, multi_in=False) + + def test_4_levels_1_idxs_t_mi_t(self): + self.tst(wid=4, levels=1, indices=True, multi_in=True) + + def test_4_levels_2_idxs_f_mi_f(self): + self.tst(wid=4, levels=2, indices=False, multi_in=False) + + def test_4_levels_2_idxs_f_mi_t(self): + self.tst(wid=4, levels=2, indices=False, multi_in=True) + + def test_4_levels_2_idxs_t_mi_f(self): + self.tst(wid=4, levels=2, indices=True, multi_in=False) + + def test_4_levels_2_idxs_t_mi_t(self): + self.tst(wid=4, levels=2, indices=True, multi_in=True) + + def test_4_levels_3_idxs_f_mi_f(self): + self.tst(wid=4, levels=3, indices=False, multi_in=False) + + def test_4_levels_3_idxs_f_mi_t(self): + self.tst(wid=4, levels=3, indices=False, multi_in=True) + + def test_4_levels_3_idxs_t_mi_f(self): + self.tst(wid=4, levels=3, indices=True, multi_in=False) + + def test_4_levels_3_idxs_t_mi_t(self): + self.tst(wid=4, levels=3, indices=True, multi_in=True) + + def test_4_levels_4_idxs_f_mi_f(self): + self.tst(wid=4, levels=4, indices=False, multi_in=False) + + def test_4_levels_4_idxs_f_mi_t(self): + self.tst(wid=4, levels=4, indices=False, multi_in=True) + + def test_4_levels_4_idxs_t_mi_f(self): + self.tst(wid=4, levels=4, indices=True, multi_in=False) + + def test_4_levels_4_idxs_t_mi_t(self): + self.tst(wid=4, levels=4, indices=True, multi_in=True) + + def test_8_levels_1_idxs_f_mi_f(self): + self.tst(wid=8, levels=1, indices=False, multi_in=False) + + def test_8_levels_1_idxs_f_mi_t(self): + self.tst(wid=8, levels=1, indices=False, multi_in=True) + + def test_8_levels_1_idxs_t_mi_f(self): + self.tst(wid=8, levels=1, indices=True, multi_in=False) + + def test_8_levels_1_idxs_t_mi_t(self): + self.tst(wid=8, levels=1, indices=True, multi_in=True) + + def test_8_levels_2_idxs_f_mi_f(self): + self.tst(wid=8, levels=2, indices=False, multi_in=False) + + def test_8_levels_2_idxs_f_mi_t(self): + self.tst(wid=8, levels=2, indices=False, multi_in=True) + + def test_8_levels_2_idxs_t_mi_f(self): + self.tst(wid=8, levels=2, indices=True, multi_in=False) + + def test_8_levels_2_idxs_t_mi_t(self): + self.tst(wid=8, levels=2, indices=True, multi_in=True) + + def test_8_levels_3_idxs_f_mi_f(self): + self.tst(wid=8, levels=3, indices=False, multi_in=False) + + def test_8_levels_3_idxs_f_mi_t(self): + self.tst(wid=8, levels=3, indices=False, multi_in=True) + + def test_8_levels_3_idxs_t_mi_f(self): + self.tst(wid=8, levels=3, indices=True, multi_in=False) + + def test_8_levels_3_idxs_t_mi_t(self): + self.tst(wid=8, levels=3, indices=True, multi_in=True) + + def test_8_levels_4_idxs_f_mi_f(self): + self.tst(wid=8, levels=4, indices=False, multi_in=False) + + def test_8_levels_4_idxs_f_mi_t(self): + self.tst(wid=8, levels=4, indices=False, multi_in=True) + + def test_8_levels_4_idxs_t_mi_f(self): + self.tst(wid=8, levels=4, indices=True, multi_in=False) + + def test_8_levels_4_idxs_t_mi_t(self): + self.tst(wid=8, levels=4, indices=True, multi_in=True) + + def test_8_levels_5_idxs_f_mi_f(self): + self.tst(wid=8, levels=5, indices=False, multi_in=False) + + def test_8_levels_5_idxs_f_mi_t(self): + self.tst(wid=8, levels=5, indices=False, multi_in=True) + + def test_8_levels_5_idxs_t_mi_f(self): + self.tst(wid=8, levels=5, indices=True, multi_in=False) + + def test_8_levels_5_idxs_t_mi_t(self): + self.tst(wid=8, levels=5, indices=True, multi_in=True) + + def test_8_levels_6_idxs_f_mi_f(self): + self.tst(wid=8, levels=6, indices=False, multi_in=False) + + def test_8_levels_6_idxs_f_mi_t(self): + self.tst(wid=8, levels=6, indices=False, multi_in=True) + + def test_8_levels_6_idxs_t_mi_f(self): + self.tst(wid=8, levels=6, indices=True, multi_in=False) + + def test_8_levels_6_idxs_t_mi_t(self): + self.tst(wid=8, levels=6, indices=True, multi_in=True) + + def test_8_levels_7_idxs_f_mi_f(self): + self.tst(wid=8, levels=7, indices=False, multi_in=False) + + def test_8_levels_7_idxs_f_mi_t(self): + self.tst(wid=8, levels=7, indices=False, multi_in=True) + + def test_8_levels_7_idxs_t_mi_f(self): + self.tst(wid=8, levels=7, indices=True, multi_in=False) + + def test_8_levels_7_idxs_t_mi_t(self): + self.tst(wid=8, levels=7, indices=True, multi_in=True) + + def test_8_levels_8_idxs_f_mi_f(self): + self.tst(wid=8, levels=8, indices=False, multi_in=False) + + def test_8_levels_8_idxs_f_mi_t(self): + self.tst(wid=8, levels=8, indices=False, multi_in=True) + + def test_8_levels_8_idxs_t_mi_f(self): + self.tst(wid=8, levels=8, indices=True, multi_in=False) + + def test_8_levels_8_idxs_t_mi_t(self): + self.tst(wid=8, levels=8, indices=True, multi_in=True) + + if __name__ == "__main__": unittest.main() diff --git a/src/nmutil/picker.py b/src/nmutil/picker.py index 7cf7f7b..ab56741 100644 --- a/src/nmutil/picker.py +++ b/src/nmutil/picker.py @@ -104,13 +104,13 @@ class MultiPriorityPicker(Elaboratable): Also outputted (optional): an index for each picked "thing". """ - def __init__(self, wid, levels, indices=False, multiin=False): + def __init__(self, wid, levels, indices=False, multi_in=False): self.levels = levels self.wid = wid self.indices = indices - self.multiin = multiin + self.multi_in = multi_in - if multiin: + if multi_in: # multiple inputs, multiple outputs. i_l = [] # array of picker outputs for j in range(self.levels): @@ -154,7 +154,7 @@ class MultiPriorityPicker(Elaboratable): p_mask = None pp_l = [] for j in range(self.levels): - if self.multiin: + if self.multi_in: i = self.i[j] else: i = self.i @@ -186,28 +186,25 @@ class MultiPriorityPicker(Elaboratable): # for each picker enabled, pass that out and set a cascading index lidx = math.ceil(math.log2(self.levels)) - prev_count = None + prev_count = 0 for j in range(self.levels): en_o = pp_l[j].en_o - if prev_count is None: - comb += self.idx_o[j].eq(0) - else: - count1 = Signal(lidx, name="count_%d" % j, reset_less=True) - comb += count1.eq(prev_count + Const(1, lidx)) - comb += self.idx_o[j].eq(Mux(en_o, count1, prev_count)) - prev_count = self.idx_o[j] + count1 = Signal(lidx, name="count_%d" % j, reset_less=True) + comb += count1.eq(prev_count + Const(1, lidx)) + comb += self.idx_o[j].eq(prev_count) + prev_count = Mux(en_o, count1, prev_count) return m def __iter__(self): - if self.multiin: + if self.multi_in: yield from self.i else: yield self.i yield from self.o + yield self.en_o if not self.indices: return - yield self.en_o yield from self.idx_o def ports(self): -- 2.30.2