"""row in the prefix-sum diagram"""
-def prefix_sum_ops(item_count, *, work_efficient=False):
+def prefix_sum_ops(item_count, *, work_efficient=False, reduce_only=False):
""" Get the associative operations needed to compute a parallel prefix-sum
of `item_count` items.
True if the algorithm used should be work-efficient -- has a larger
depth (about twice as large) but does only `O(N)` operations total
instead of `O(N*log(N))`.
+ reduce_only: bool
+ True if the work-efficient algorithm should stop after the initial
+ tree-reduction step.
Returns: Iterable[Op]
output associative operations.
"""
yield Op(out=i, lhs=i - dist, rhs=i, row=row)
dist <<= 1
row += 1
- if work_efficient:
+ if work_efficient and not reduce_only:
# express all output items in terms of the computed partial sums.
dist >>= 1
while dist >= 1: