# SPDX-License-Identifier: LGPL-3-or-later
# Copyright 2022 Jacob Lifshay programmerjake@gmail.com

# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.

import operator
from collections import defaultdict
from dataclasses import dataclass


@dataclass(order=True, unsafe_hash=True, frozen=True)
class Op:
    """An associative operation in a prefix-sum.
    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.

    Instances are immutable and hashable, and they compare/sort by their
    fields in declaration order: `out`, `lhs`, `rhs`, `row`.
    """
    out: int
    """index of the item to output to"""
    lhs: int
    """index of the item the left-hand-side input comes from"""
    rhs: int
    """index of the item the right-hand-side input comes from"""
    row: int
    """row in the prefix-sum diagram"""


def prefix_sum_ops(item_count, *, work_efficient=False, reduce_only=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `item_count` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Ops are yielded in non-decreasing `row` order, and within a row in
    decreasing `out` order. No op reads an item that an earlier op of the
    same row has written, so applying the ops one at a time gives the same
    result as applying every op of a row in parallel.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    reduce_only: bool
        True if the work-efficient algorithm should stop after the initial
        tree-building pass, skipping the second pass that turns the partial
        sums into complete prefix-sums. Has no effect unless
        `work_efficient` is true.
    Returns: Iterable[Op]
        output associative operations.
    """
    assert isinstance(item_count, int)
    # compute the partial sums using a set of binary trees
    # this is the first half of the work-efficient algorithm and the whole of
    # the non-work-efficient algorithm.
    dist = 1
    row = 0
    while dist < item_count:
        start = dist * 2 - 1 if work_efficient else dist
        step = dist * 2 if work_efficient else 1
        # walk downward so every op reads `i - dist` before this row
        # overwrites it
        for i in reversed(range(start, item_count, step)):
            yield Op(out=i, lhs=i - dist, rhs=i, row=row)
        dist *= 2
        row += 1
    if work_efficient and not reduce_only:
        # express all output items in terms of the computed partial sums.
        dist //= 2
        while dist >= 1:
            for i in reversed(range(dist * 3 - 1, item_count, dist * 2)):
                yield Op(out=i, lhs=i - dist, rhs=i, row=row)
            dist //= 2
            row += 1


def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Compute the parallel prefix-sum of `items`, using associative operator
    `fn` instead of addition.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items. Any iterable is accepted; it is copied into a new list,
        so the caller's container is never modified.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        output items; for an associative `fn`, element `i` equals `fn`
        folded left-to-right over `items[0]` through `items[i]`.
    """
    # work on a private copy: the ops update results in place
    items = list(items)
    for op in prefix_sum_ops(len(items), work_efficient=work_efficient):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items


@dataclass
class _Cell:
    """Drawing flags for one grid position of a prefix-sum diagram."""
    # a diagonal line passes through this position
    slant: bool
    # an operation produces its output at this position
    plus: bool
    # a diagonal line leaves this position's vertical line toward the
    # bottom right
    tee: bool


def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells; must be >= 0.
    Returns: str
        the diagram text, one item per column, with trailing whitespace
        stripped from every line.
    """
    assert isinstance(item_count, int)
    assert isinstance(padding, int)
    # a negative padding leaves no text line for the cell centers
    assert padding >= 0, f"padding must be non-negative: {padding}"
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    # grid of cells with one column per item; grid row 0 is the top edge
    cells = [blank_row()]

    for row in sorted(ops_by_row):
        ops = ops_by_row[row]
        # each op's diagonal descends one grid row per column it spans, so
        # this diagram row needs as many new grid rows as its widest op
        max_distance = max(op.rhs - op.lhs for op in ops)
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # the operator goes on the newest grid row in the output column,
            # then the diagonal is walked up-left to the lhs column
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            cells[y][x].tee = True

    lines = []
    for cells_row in cells:
        # every cell becomes a square of (2 * padding + 1) text lines
        row_text = [[] for _ in range(2 * padding + 1)]
        for cell in cells_row:
            # top padding: a diagonal enters from the top left
            for y in range(padding):
                for x in range(padding):
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                row_text[y].append(vbar)
                row_text[y].extend(sp for _ in range(padding))
            # center line: operator, connection, crossing or plain bar
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            else:
                center = vbar
            row_text[padding].extend(sp for _ in range(padding))
            row_text[padding].append(center)
            row_text[padding].extend(sp for _ in range(padding))
            # bottom padding: a diagonal leaves toward the bottom right
            for y in range(padding + 1, 2 * padding + 1):
                row_text[y].extend(sp for _ in range(padding))
                row_text[y].append(vbar)
                for x in range(padding + 1, 2 * padding + 1):
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        lines.extend("".join(line) for line in row_text)

    return "\n".join(map(str.rstrip, lines))


if __name__ == "__main__":
    # render both algorithms for 16 items, each after a link to the
    # wikipedia diagram it is meant to match
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print()
    print()
    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=True))