add partial_prefix_sum_ops
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 5 Aug 2022 06:15:30 +0000 (23:15 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 5 Aug 2022 06:15:30 +0000 (23:15 -0700)
src/nmutil/prefix_sum.py

index 3908d1b5495c402059dca54040827e9fae941ee6..fc2f3d172a94d1706452b86e375e8497b018b4da 100644 (file)
@@ -220,11 +220,37 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
     return "\n".join(map(str.rstrip, lines))
 
 
-def tree_reduction_ops(item_count):
-    assert isinstance(item_count, int) and item_count >= 1
-    ops = list(prefix_sum_ops(item_count=item_count))
-    items_live_flags = [False] * item_count
-    items_live_flags[-1] = True
+def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
+    """ Get the associative operations needed to compute a parallel prefix-sum
+    of `len(needed_outputs)` items.
+
+    The operations aren't assumed to be commutative.
+
+    This has a depth of `O(log(N))` and an operation count of `O(N)` if
+    `work_efficient` is true, otherwise `O(N*log(N))`.
+
+    The algorithms used are derived from:
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
+
+    Parameters:
+    needed_outputs: Iterable[bool]
+        The length is the number of input/output items.
+        Each item is True if that corresponding output is needed.
+        Unneeded outputs have unspecified value.
+    work_efficient: bool
+        True if the algorithm used should be work-efficient -- has a larger
+        depth (about twice as large) but does only `O(N)` operations total
+        instead of `O(N*log(N))`.
+    Returns: Iterable[Op]
+        output associative operations.
+    """
+    def assert_bool(v):
+        assert isinstance(v, bool)
+        return v
+    items_live_flags = [assert_bool(i) for i in needed_outputs]
+    ops = list(prefix_sum_ops(item_count=len(items_live_flags),
+                              work_efficient=work_efficient))
     ops_live_flags = [False] * len(ops)
     for i in reversed(range(len(ops))):
         op = ops[i]
@@ -238,6 +264,12 @@ def tree_reduction_ops(item_count):
             yield op
 
 
+def tree_reduction_ops(item_count):
+    assert isinstance(item_count, int) and item_count >= 1
+    needed_outputs = (i == item_count - 1 for i in range(item_count))
+    return partial_prefix_sum_ops(needed_outputs)
+
+
 def tree_reduction(items, fn=operator.add):
     items = list(items)
     for op in tree_reduction_ops(len(items)):