return "\n".join(map(str.rstrip, lines))
-def tree_reduction_ops(item_count):
- assert isinstance(item_count, int) and item_count >= 1
- ops = list(prefix_sum_ops(item_count=item_count))
- items_live_flags = [False] * item_count
- items_live_flags[-1] = True
+def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
+ """ Get the associative operations needed to compute a parallel prefix-sum
+ of `len(needed_outputs)` items.
+
+ The operations aren't assumed to be commutative.
+
+ This has a depth of `O(log(N))` and an operation count of `O(N)` if
+ `work_efficient` is true, otherwise `O(N*log(N))`.
+
+ The algorithms used are derived from:
+ https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
+ https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
+
+ Parameters:
+ needed_outputs: Iterable[bool]
+ The length is the number of input/output items.
+ Each item is True if that corresponding output is needed.
+ Unneeded outputs have unspecified value.
+ work_efficient: bool
+ True if the algorithm used should be work-efficient -- has a larger
+ depth (about twice as large) but does only `O(N)` operations total
+ instead of `O(N*log(N))`.
+ Returns: Iterable[Op]
+ output associative operations.
+ """
+ def assert_bool(v):
+ assert isinstance(v, bool)
+ return v
+ items_live_flags = [assert_bool(i) for i in needed_outputs]
+ ops = list(prefix_sum_ops(item_count=len(items_live_flags),
+ work_efficient=work_efficient))
ops_live_flags = [False] * len(ops)
for i in reversed(range(len(ops))):
op = ops[i]
yield op
+def tree_reduction_ops(item_count):
+    """ Get the associative operations needed to reduce `item_count` items
+    down to a single combined result.
+
+    This is a thin wrapper around `partial_prefix_sum_ops` that marks only
+    the last output as needed; every other output is flagged as unneeded.
+
+    Parameters:
+    item_count: int
+        The number of items to reduce. Must be an `int` that is at least 1.
+    Returns:
+        the result of `partial_prefix_sum_ops` for those flags, using its
+        default `work_efficient` setting.
+    """
+    assert isinstance(item_count, int)
+    assert item_count >= 1
+    # Only the final output (the combination of all items) is requested.
+    only_last_needed = [False] * (item_count - 1) + [True]
+    return partial_prefix_sum_ops(only_last_needed)
+
+
def tree_reduction(items, fn=operator.add):
items = list(items)
for op in tree_reduction_ops(len(items)):