fix prefix_sum.py after 63ffb1aa and d7288021
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 12 Aug 2022 06:37:02 +0000 (23:37 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 12 Aug 2022 06:37:02 +0000 (23:37 -0700)
src/nmutil/prefix_sum.py

index 549d772d78dc25aaa13aa8cf003e8cc90416971b..65f2b33eaf079889590d5aa13ffb6fdb28f1628c 100644 (file)
@@ -7,18 +7,29 @@
 from collections import defaultdict
 import operator
 from nmigen.hdl.ast import Value, Const
+from nmutil.plain_data import plain_data
 
 
+@plain_data(order=True, unsafe_hash=True, frozen=True)
 class Op:
     """An associative operation in a prefix-sum.
     The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
     The operation is not assumed to be commutative.
     """
-    def __init__(self,*, out, lhs, rhs, row):
-        self.out = out; "index of the item to output to"
-        self.lhs = lhs; "index of item the left-hand-side input comes from"
-        self.rhs = rhs; "index of item the right-hand-side input comes from"
-        self.row = row; "row in the prefix-sum diagram"
+    __slots__ = "out", "lhs", "rhs", "row"
+
+    def __init__(self, out, lhs, rhs, row):
+        self.out = out
+        """index of the item to output to"""
+
+        self.lhs = lhs
+        """index of the item the left-hand-side input comes from"""
+
+        self.rhs = rhs
+        """index of the item the right-hand-side input comes from"""
+
+        self.row = row
+        """row in the prefix-sum diagram"""
 
 
 def prefix_sum_ops(item_count, *, work_efficient=False):
@@ -44,8 +55,9 @@ def prefix_sum_ops(item_count, *, work_efficient=False):
     Returns: Iterable[Op]
         output associative operations.
     """
+    assert isinstance(item_count, int)
     # compute the partial sums using a set of binary trees
-    # first half of the work-efficient algorithm and the whole of
+    # this is the first half of the work-efficient algorithm and the whole of
     # the non-work-efficient algorithm.
     dist = 1
     row = 0
@@ -96,8 +108,11 @@ def prefix_sum(items, fn=operator.add, *, work_efficient=False):
     return items
 
 
+@plain_data()
 class _Cell:
-    def __init__(self, *, slant, plus, tee):
+    __slots__ = "slant", "plus", "tee"
+
+    def __init__(self, slant, plus, tee):
         self.slant = slant
         self.plus = plus
         self.tee = tee
@@ -138,6 +153,8 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
     """
     ops_by_row = defaultdict(set)
     for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
+        assert op.out == op.rhs, f"can't draw op: {op}"
+        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
         ops_by_row[op.row].add(op)
 
     def blank_row():
@@ -151,6 +168,7 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
         max_distance = max(op.rhs - op.lhs for op in ops)
         cells.extend(blank_row() for _ in range(max_distance))
         for op in ops:
+            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
             y = len(cells) - 1
             x = op.out
             cells[y][x].plus = True
@@ -234,7 +252,10 @@ def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
     Returns: Iterable[Op]
         output associative operations.
     """
-    items_live_flags = needed_outputs
+
+    # needed_outputs is an iterable, we need to construct a new list so we
+    # don't modify the passed-in value
+    items_live_flags = [bool(i) for i in needed_outputs]
     ops = list(prefix_sum_ops(item_count=len(items_live_flags),
                               work_efficient=work_efficient))
     ops_live_flags = [False] * len(ops)
@@ -251,6 +272,7 @@ def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
 
 
 def tree_reduction_ops(item_count):
+    assert item_count >= 1
     needed_outputs = (i == item_count - 1 for i in range(item_count))
     return partial_prefix_sum_ops(needed_outputs)
 
@@ -263,13 +285,36 @@ def tree_reduction(items, fn=operator.add):
 
 
 def pop_count(v, *, width=None, process_temporary=lambda v: v):
+    """ return the population count (number of 1 bits) of `v`.
+    Arguments:
+    v: nmigen.Value | int
+        the value to calculate the pop-count of.
+    width: int | None
+        the bit-width of `v`.
+        If `width` is None, then `v` must be a nmigen Value or
+        match `v`'s width.
+    process_temporary: function of (type(v)) -> type(v)
+        called after every addition operation, can be used to introduce
+        `Signal`s for the intermediate values in the pop-count computation
+        like so:
+
+        ```
+        def process_temporary(v):
+            sig = Signal.like(v)
+            m.d.comb += sig.eq(v)
+            return sig
+        ```
+    """
     if isinstance(v, Value):
         if width is None:
             width = len(v)
+        assert width == len(v)
         bits = [v[i] for i in range(width)]
         if len(bits) == 0:
             return Const(0)
     else:
+        assert width is not None, "width must be given"
+        # v and width are ints
         bits = [(v & (1 << i)) != 0 for i in range(width)]
         if len(bits) == 0:
             return 0