add formal proof for MultiPriorityPicker
[nmutil.git] / src / nmutil / picker.py
index 7cf7f7bd03001f920ce898ab765211d7689832ad..ab56741d0d5c5ff56324d8b818e746729f6df5e7 100644 (file)
@@ -104,13 +104,13 @@ class MultiPriorityPicker(Elaboratable):
         Also outputted (optional): an index for each picked "thing".
     """
 
-    def __init__(self, wid, levels, indices=False, multiin=False):
+    def __init__(self, wid, levels, indices=False, multi_in=False):
         self.levels = levels
         self.wid = wid
         self.indices = indices
-        self.multiin = multiin
+        self.multi_in = multi_in
 
-        if multiin:
+        if multi_in:
             # multiple inputs, multiple outputs.
             i_l = []  # array of picker outputs
             for j in range(self.levels):
@@ -154,7 +154,7 @@ class MultiPriorityPicker(Elaboratable):
         p_mask = None
         pp_l = []
         for j in range(self.levels):
-            if self.multiin:
+            if self.multi_in:
                 i = self.i[j]
             else:
                 i = self.i
@@ -186,28 +186,25 @@ class MultiPriorityPicker(Elaboratable):
 
         # for each picker enabled, pass that out and set a cascading index
         lidx = math.ceil(math.log2(self.levels))
-        prev_count = None
+        prev_count = 0
         for j in range(self.levels):
             en_o = pp_l[j].en_o
-            if prev_count is None:
-                comb += self.idx_o[j].eq(0)
-            else:
-                count1 = Signal(lidx, name="count_%d" % j, reset_less=True)
-                comb += count1.eq(prev_count + Const(1, lidx))
-                comb += self.idx_o[j].eq(Mux(en_o, count1, prev_count))
-            prev_count = self.idx_o[j]
+            count1 = Signal(lidx, name="count_%d" % j, reset_less=True)
+            comb += count1.eq(prev_count + Const(1, lidx))
+            comb += self.idx_o[j].eq(prev_count)
+            prev_count = Mux(en_o, count1, prev_count)
 
         return m
 
     def __iter__(self):
-        if self.multiin:
+        if self.multi_in:
             yield from self.i
         else:
             yield self.i
         yield from self.o
+        yield self.en_o
         if not self.indices:
             return
-        yield self.en_o
         yield from self.idx_o
 
     def ports(self):